Delete main
This view is limited to 50 files because it contains too many changes.
- main/app/app.py +0 -0
- main/app/parser.py +0 -340
- main/app/tensorboard.py +0 -30
- main/configs/config.json +0 -547
- main/configs/config.py +0 -90
- main/configs/decrypt.bin +0 -3
- main/configs/v1/32000.json +0 -46
- main/configs/v1/40000.json +0 -46
- main/configs/v1/48000.json +0 -46
- main/configs/v2/32000.json +0 -42
- main/configs/v2/40000.json +0 -42
- main/configs/v2/48000.json +0 -42
- main/inference/audio_effects.py +0 -180
- main/inference/audioldm2.py +0 -210
- main/inference/convert.py +0 -590
- main/inference/create_dataset.py +0 -230
- main/inference/create_index.py +0 -90
- main/inference/extract.py +0 -360
- main/inference/preprocess.py +0 -270
- main/inference/separator_music.py +0 -310
- main/inference/train.py +0 -990
- main/library/algorithm/commons.py +0 -60
- main/library/algorithm/modules.py +0 -60
- main/library/algorithm/mrf_hifigan.py +0 -150
- main/library/algorithm/onnx_export.py +0 -50
- main/library/algorithm/refinegan.py +0 -170
- main/library/algorithm/residuals.py +0 -140
- main/library/algorithm/separator.py +0 -320
- main/library/algorithm/stftpitchshift.py +0 -250
- main/library/algorithm/synthesizers.py +0 -490
- main/library/architectures/demucs_separator.py +0 -180
- main/library/architectures/fairseq.py +0 -1480
- main/library/architectures/mdx_separator.py +0 -320
- main/library/audioldm2/models.py +0 -330
- main/library/audioldm2/utils.py +0 -40
- main/library/predictors/CREPE.py +0 -210
- main/library/predictors/FCPE.py +0 -1097
- main/library/predictors/RMVPE.py +0 -260
- main/library/predictors/SWIPE.py +0 -140
- main/library/predictors/WORLD_WRAPPER.py +0 -90
- main/library/speaker_diarization/ECAPA_TDNN.py +0 -280
- main/library/speaker_diarization/audio.py +0 -170
- main/library/speaker_diarization/embedding.py +0 -90
- main/library/speaker_diarization/encoder.py +0 -250
- main/library/speaker_diarization/features.py +0 -520
- main/library/speaker_diarization/parameter_transfer.py +0 -120
- main/library/speaker_diarization/segment.py +0 -540
- main/library/speaker_diarization/speechbrain.py +0 -220
- main/library/speaker_diarization/whisper.py +0 -1290
- main/library/utils.py +0 -240
main/app/app.py
DELETED
The diff for this file is too large to render.
main/app/parser.py
DELETED
@@ -1,340 +0,0 @@
import os
import sys

sys.path.append(os.getcwd())

try:
    argv = sys.argv[1]
except IndexError:
    argv = None

argv_is_allows = ["--audio_effects", "--audioldm2", "--convert", "--create_dataset", "--create_index", "--extract", "--preprocess", "--separator_music", "--train", "--help_audio_effects", "--help_audioldm2", "--help_convert", "--help_create_dataset", "--help_create_index", "--help_extract", "--help_preprocess", "--help_separator_music", "--help_train", "--help"]

if argv not in argv_is_allows:
    print("Cú pháp không hợp lệ! Sử dụng --help để biết thêm")
    quit()

if argv_is_allows[0] in argv: from main.inference.audio_effects import main
elif argv_is_allows[1] in argv: from main.inference.audioldm2 import main
elif argv_is_allows[2] in argv: from main.inference.convert import main
elif argv_is_allows[3] in argv: from main.inference.create_dataset import main
elif argv_is_allows[4] in argv: from main.inference.create_index import main
elif argv_is_allows[5] in argv: from main.inference.extract import main
elif argv_is_allows[6] in argv: from main.inference.preprocess import main
elif argv_is_allows[7] in argv: from main.inference.separator_music import main
elif argv_is_allows[8] in argv: from main.inference.train import main
elif argv_is_allows[9] in argv:
    print("""Các tham số của `--audio_effects`:
1. Đường dẫn tệp:
    - `--input_path` (bắt buộc): Đường dẫn đến tệp âm thanh đầu vào.
    - `--output_path` (mặc định: `./audios/apply_effects.wav`): Đường dẫn lưu tệp đầu ra.
    - `--export_format` (mặc định: `wav`): Định dạng xuất tệp (`wav`, `mp3`, ...).

2. Lấy mẫu lại:
    - `--resample` (mặc định: `False`): Có lấy mẫu lại hay không.
    - `--resample_sr` (mặc định: `0`): Tần số lấy mẫu mới (Hz).

3. Hiệu ứng chorus:
    - `--chorus`: Bật/tắt chorus.
    - `--chorus_depth`, `--chorus_rate`, `--chorus_mix`, `--chorus_delay`, `--chorus_feedback`: Các thông số điều chỉnh chorus.

4. Hiệu ứng distortion:
    - `--distortion`: Bật/tắt distortion.
    - `--drive_db`: Mức độ méo âm thanh.

5. Hiệu ứng reverb:
    - `--reverb`: Bật/tắt hồi âm.
    - `--reverb_room_size`, `--reverb_damping`, `--reverb_wet_level`, `--reverb_dry_level`, `--reverb_width`, `--reverb_freeze_mode`: Điều chỉnh hồi âm.

6. Hiệu ứng pitch shift:
    - `--pitchshift`: Bật/tắt thay đổi cao độ.
    - `--pitch_shift`: Giá trị dịch cao độ.

7. Hiệu ứng delay:
    - `--delay`: Bật/tắt delay.
    - `--delay_seconds`, `--delay_feedback`, `--delay_mix`: Điều chỉnh thời gian trễ, phản hồi và hòa trộn.

8. Compressor:
    - `--compressor`: Bật/tắt compressor.
    - `--compressor_threshold`, `--compressor_ratio`, `--compressor_attack_ms`, `--compressor_release_ms`: Các thông số nén.

9. Limiter:
    - `--limiter`: Bật/tắt giới hạn mức âm thanh.
    - `--limiter_threshold`, `--limiter_release`: Ngưỡng giới hạn và thời gian nhả.

10. Gain (Khuếch đại):
    - `--gain`: Bật/tắt gain.
    - `--gain_db`: Mức gain (dB).

11. Bitcrush:
    - `--bitcrush`: Bật/tắt hiệu ứng giảm độ phân giải.
    - `--bitcrush_bit_depth`: Số bit của bitcrush.

12. Clipping:
    - `--clipping`: Bật/tắt cắt âm thanh.
    - `--clipping_threshold`: Ngưỡng clipping.

13. Phaser:
    - `--phaser`: Bật/tắt hiệu ứng phaser.
    - `--phaser_rate_hz`, `--phaser_depth`, `--phaser_centre_frequency_hz`, `--phaser_feedback`, `--phaser_mix`: Điều chỉnh hiệu ứng phaser.

14. Boost bass & treble:
    - `--treble_bass_boost`: Bật/tắt tăng cường âm bass và treble.
    - `--bass_boost_db`, `--bass_boost_frequency`, `--treble_boost_db`, `--treble_boost_frequency`: Các thông số tăng bass và treble.

15. Fade in & fade out:
    - `--fade_in_out`: Bật/tắt hiệu ứng fade.
    - `--fade_in_duration`, `--fade_out_duration`: Thời gian fade vào/ra.

16. Kết hợp âm thanh:
    - `--audio_combination`: Bật/tắt ghép nhiều tệp âm thanh.
    - `--audio_combination_input`: Đường dẫn tệp âm thanh bổ sung.
    """)
    quit()
elif argv_is_allows[10] in argv:
    print("""Các tham số của --audioldm2:
1. Đường dẫn tệp:
    - `--input_path` (bắt buộc): Đường dẫn đến tệp âm thanh đầu vào.
    - `--output_path` (mặc định: `./output.wav`): Đường dẫn lưu tệp đầu ra.
    - `--export_format` (mặc định: `wav`): Định dạng xuất tệp.

2. Cấu hình âm thanh:
    - `--sample_rate` (mặc định: `44100`): Tần số lấy mẫu (Hz).

3. Cấu hình mô hình AudioLDM:
    - `--audioldm_model` (mặc định: `audioldm2-music`): Chọn mô hình AudioLDM để xử lý.

4. Prompt hướng dẫn mô hình:
    - `--source_prompt` (mặc định: ``): Mô tả âm thanh nguồn.
    - `--target_prompt` (mặc định: ``): Mô tả âm thanh đích.

5. Cấu hình thuật toán xử lý:
    - `--steps` (mặc định: `200`): Số bước xử lý trong quá trình tổng hợp âm thanh.
    - `--cfg_scale_src` (mặc định: `3.5`): Hệ số điều chỉnh hướng dẫn cho âm thanh nguồn.
    - `--cfg_scale_tar` (mặc định: `12`): Hệ số điều chỉnh hướng dẫn cho âm thanh đích.
    - `--t_start` (mặc định: `45`): Mức độ chỉnh sửa.

6. Tối ưu hóa tính toán:
    - `--save_compute` (mặc định: `False`): Có bật chế độ tối ưu tính toán hay không.
    """)
    quit()
elif argv_is_allows[11] in argv:
    print("""Các tham số của --convert:
1. Cấu hình xử lý giọng nói:
    - `--pitch` (mặc định: `0`): Điều chỉnh cao độ.
    - `--filter_radius` (mặc định: `3`): Độ mượt của đường F0.
    - `--index_rate` (mặc định: `0.5`): Tỷ lệ sử dụng chỉ mục giọng nói.
    - `--volume_envelope` (mặc định: `1`): Hệ số điều chỉnh biên độ âm lượng.
    - `--protect` (mặc định: `0.33`): Bảo vệ phụ âm.

2. Cấu hình mẫu (frame hop):
    - `--hop_length` (mặc định: `64`): Bước nhảy khi xử lý âm thanh.

3. Cấu hình F0:
    - `--f0_method` (mặc định: `rmvpe`): Phương pháp dự đoán F0 (`pm`, `dio`, `mangio-crepe-tiny`, `mangio-crepe-small`, `mangio-crepe-medium`, `mangio-crepe-large`, `mangio-crepe-full`, `crepe-tiny`, `crepe-small`, `crepe-medium`, `crepe-large`, `crepe-full`, `fcpe`, `fcpe-legacy`, `rmvpe`, `rmvpe-legacy`, `harvest`, `yin`, `pyin`, `swipe`).
    - `--f0_autotune` (mặc định: `False`): Có tự động điều chỉnh F0 hay không.
    - `--f0_autotune_strength` (mặc định: `1`): Cường độ hiệu chỉnh tự động F0.
    - `--f0_file` (mặc định: ``): Đường dẫn tệp F0 có sẵn.
    - `--f0_onnx` (mặc định: `False`): Có sử dụng phiên bản ONNX của F0 hay không.

4. Mô hình nhúng:
    - `--embedder_model` (mặc định: `contentvec_base`): Mô hình nhúng sử dụng.
    - `--embedders_mode` (mặc định: `fairseq`): Chế độ nhúng (`fairseq`, `transformers`, `onnx`).

5. Đường dẫn tệp:
    - `--input_path` (bắt buộc): Đường dẫn tệp âm thanh đầu vào.
    - `--output_path` (mặc định: `./audios/output.wav`): Đường dẫn lưu tệp đầu ra.
    - `--export_format` (mặc định: `wav`): Định dạng xuất tệp.
    - `--pth_path` (bắt buộc): Đường dẫn đến tệp mô hình `.pth`.
    - `--index_path` (mặc định: `None`): Đường dẫn tệp chỉ mục (nếu có).

6. Làm sạch âm thanh:
    - `--clean_audio` (mặc định: `False`): Có áp dụng làm sạch âm thanh không.
    - `--clean_strength` (mặc định: `0.7`): Mức độ làm sạch.

7. Resampling & chia nhỏ âm thanh:
    - `--resample_sr` (mặc định: `0`): Tần số lấy mẫu mới (0 nghĩa là giữ nguyên).
    - `--split_audio` (mặc định: `False`): Có chia nhỏ audio trước khi xử lý không.

8. Kiểm tra & tối ưu hóa:
    - `--checkpointing` (mặc định: `False`): Bật/tắt checkpointing để tiết kiệm RAM.

9. Dịch formant:
    - `--formant_shifting` (mặc định: `False`): Có bật hiệu ứng dịch formant không.
    - `--formant_qfrency` (mặc định: `0.8`): Hệ số dịch formant theo tần số.
    - `--formant_timbre` (mặc định: `0.8`): Hệ số thay đổi màu sắc giọng.
    """)
    quit()
elif argv_is_allows[12] in argv:
    print("""Các tham số của --create_dataset:
1. Đường dẫn & cấu hình dataset:
    - `--input_audio` (bắt buộc): Đường dẫn liên kết đến âm thanh (Liên kết Youtube, có thể dùng dấu `,` để dùng nhiều liên kết).
    - `--output_dataset` (mặc định: `./dataset`): Thư mục xuất dữ liệu đầu ra.
    - `--sample_rate` (mặc định: `44100`): Tần số lấy mẫu cho âm thanh.

2. Làm sạch dữ liệu:
    - `--clean_dataset` (mặc định: `False`): Có áp dụng làm sạch dữ liệu hay không.
    - `--clean_strength` (mặc định: `0.7`): Mức độ làm sạch dữ liệu.

3. Tách giọng & hiệu ứng:
    - `--separator_reverb` (mặc định: `False`): Có tách vang giọng không.
    - `--kim_vocal_version` (mặc định: `2`): Phiên bản mô hình Kim Vocal để tách (`1`, `2`).

4. Cấu hình phân đoạn âm thanh:
    - `--overlap` (mặc định: `0.25`): Mức độ chồng lấn giữa các đoạn khi tách.
    - `--segments_size` (mặc định: `256`): Kích thước của từng phân đoạn.

5. Cấu hình MDX (Music Demixing):
    - `--mdx_hop_length` (mặc định: `1024`): Bước nhảy MDX khi xử lý.
    - `--mdx_batch_size` (mặc định: `1`): Kích thước batch khi xử lý MDX.
    - `--denoise_mdx` (mặc định: `False`): Có áp dụng khử nhiễu khi tách bằng MDX không.

6. Bỏ qua phần âm thanh:
    - `--skip` (mặc định: `False`): Có bỏ qua giây âm thanh nào không.
    - `--skip_start_audios` (mặc định: `0`): Thời gian (giây) cần bỏ qua ở đầu audio.
    - `--skip_end_audios` (mặc định: `0`): Thời gian (giây) cần bỏ qua ở cuối audio.
    """)
    quit()
elif argv_is_allows[13] in argv:
    print("""Các tham số của --create_index:
1. Thông tin mô hình:
    - `--model_name` (bắt buộc): Tên mô hình.
    - `--rvc_version` (mặc định: `v2`): Phiên bản (`v1`, `v2`).
    - `--index_algorithm` (mặc định: `Auto`): Thuật toán index sử dụng (`Auto`, `Faiss`, `KMeans`).
    """)
    quit()
elif argv_is_allows[14] in argv:
    print("""Các tham số của --extract:
1. Thông tin mô hình:
    - `--model_name` (bắt buộc): Tên mô hình.
    - `--rvc_version` (mặc định: `v2`): Phiên bản RVC (`v1`, `v2`).

2. Cấu hình F0:
    - `--f0_method` (mặc định: `rmvpe`): Phương pháp dự đoán F0 (`pm`, `dio`, `mangio-crepe-tiny`, `mangio-crepe-small`, `mangio-crepe-medium`, `mangio-crepe-large`, `mangio-crepe-full`, `crepe-tiny`, `crepe-small`, `crepe-medium`, `crepe-large`, `crepe-full`, `fcpe`, `fcpe-legacy`, `rmvpe`, `rmvpe-legacy`, `harvest`, `yin`, `pyin`, `swipe`).
    - `--pitch_guidance` (mặc định: `True`): Có sử dụng hướng dẫn cao độ hay không.

3. Cấu hình xử lý:
    - `--hop_length` (mặc định: `128`): Độ dài bước nhảy trong quá trình xử lý.
    - `--cpu_cores` (mặc định: `2`): Số lượng luồng CPU sử dụng.
    - `--gpu` (mặc định: `-`): Chỉ định GPU sử dụng (ví dụ: `0` cho GPU đầu tiên, `-` để tắt GPU).
    - `--sample_rate` (bắt buộc): Tần số lấy mẫu của âm thanh đầu vào.

4. Cấu hình nhúng:
    - `--embedder_model` (mặc định: `contentvec_base`): Tên mô hình nhúng.
    - `--f0_onnx` (mặc định: `False`): Có sử dụng phiên bản ONNX của F0 hay không.
    - `--embedders_mode` (mặc định: `fairseq`): Chế độ nhúng (`fairseq`, `transformers`, `onnx`).
    """)
    quit()
elif argv_is_allows[15] in argv:
    print("""Các tham số của --preprocess:
1. Thông tin mô hình:
    - `--model_name` (bắt buộc): Tên mô hình.

2. Cấu hình dữ liệu:
    - `--dataset_path` (mặc định: `./dataset`): Đường dẫn thư mục chứa tệp dữ liệu.
    - `--sample_rate` (bắt buộc): Tần số lấy mẫu của dữ liệu âm thanh.

3. Cấu hình xử lý:
    - `--cpu_cores` (mặc định: `2`): Số lượng luồng CPU sử dụng.
    - `--cut_preprocess` (mặc định: `True`): Có cắt tệp dữ liệu hay không.
    - `--process_effects` (mặc định: `False`): Có áp dụng tiền xử lý hay không.
    - `--clean_dataset` (mặc định: `False`): Có làm sạch tệp dữ liệu hay không.
    - `--clean_strength` (mặc định: `0.7`): Độ mạnh của quá trình làm sạch dữ liệu.
    """)
    quit()
elif argv_is_allows[16] in argv:
    print("""Các tham số của --separator_music:
1. Đường dẫn dữ liệu:
    - `--input_path` (bắt buộc): Đường dẫn tệp âm thanh đầu vào.
    - `--output_path` (mặc định: `./audios`): Thư mục lưu tệp đầu ra.
    - `--format` (mặc định: `wav`): Định dạng xuất tệp (`wav`, `mp3`,...).

2. Cấu hình xử lý âm thanh:
    - `--shifts` (mặc định: `2`): Số lượng dự đoán.
    - `--segments_size` (mặc định: `256`): Kích thước phân đoạn âm thanh.
    - `--overlap` (mặc định: `0.25`): Mức độ chồng lấn giữa các đoạn.
    - `--mdx_hop_length` (mặc định: `1024`): Bước nhảy MDX khi xử lý.
    - `--mdx_batch_size` (mặc định: `1`): Kích thước lô.

3. Xử lý làm sạch:
    - `--clean_audio` (mặc định: `False`): Có làm sạch âm thanh hay không.
    - `--clean_strength` (mặc định: `0.7`): Độ mạnh của bộ lọc làm sạch.

4. Cấu hình mô hình:
    - `--model_name` (mặc định: `HT-Normal`): Mô hình tách nhạc (`Main_340`, `Main_390`, `Main_406`, `Main_427`, `Main_438`, `Inst_full_292`, `Inst_HQ_1`, `Inst_HQ_2`, `Inst_HQ_3`, `Inst_HQ_4`, `Inst_HQ_5`, `Kim_Vocal_1`, `Kim_Vocal_2`, `Kim_Inst`, `Inst_187_beta`, `Inst_82_beta`, `Inst_90_beta`, `Voc_FT`, `Crowd_HQ`, `Inst_1`, `Inst_2`, `Inst_3`, `MDXNET_1_9703`, `MDXNET_2_9682`, `MDXNET_3_9662`, `Inst_Main`, `MDXNET_Main`, `MDXNET_9482`, `HT-Normal`, `HT-Tuned`, `HD_MMI`, `HT_6S`).
    - `--kara_model` (mặc định: `Version-1`): Phiên bản mô hình tách bè (`Version-1`, `Version-2`).

5. Hiệu ứng và xử lý hậu kỳ:
    - `--backing` (mặc định: `False`): Có tách bè hay không.
    - `--mdx_denoise` (mặc định: `False`): Có sử dụng khử nhiễu MDX hay không.
    - `--reverb` (mặc định: `False`): Có tách vang hay không.
    - `--backing_reverb` (mặc định: `False`): có tách vang cho giọng bè không.

6. Tần số lấy mẫu:
    - `--sample_rate` (mặc định: `44100`): Tần số lấy mẫu của âm thanh đầu ra.
    """)
    quit()
elif argv_is_allows[17] in argv:
    print("""Các tham số của --train:
1. Cấu hình mô hình:
    - `--model_name` (bắt buộc): Tên mô hình.
    - `--rvc_version` (mặc định: `v2`): Phiên bản RVC (`v1`, `v2`).
    - `--model_author` (tùy chọn): Tác giả của mô hình.

2. Cấu hình lưu:
    - `--save_every_epoch` (bắt buộc): Số kỷ nguyên giữa mỗi lần lưu.
    - `--save_only_latest` (mặc định: `True`): Chỉ lưu điểm mới nhất.
    - `--save_every_weights` (mặc định: `True`): Lưu tất cả trọng số của mô hình.

3. Cấu hình huấn luyện:
    - `--total_epoch` (mặc định: `300`): Tổng số kỷ nguyên huấn luyện.
    - `--batch_size` (mặc định: `8`): Kích thước lô trong quá trình huấn luyện.
    - `--sample_rate` (bắt buộc): Tần số lấy mẫu của âm thanh.

4. Cấu hình thiết bị:
    - `--gpu` (mặc định: `0`): Chỉ định GPU để sử dụng (số hoặc `-` nếu không dùng GPU).
    - `--cache_data_in_gpu` (mặc định: `False`): Lưu dữ liệu vào GPU để tăng tốc.

5. Cấu hình huấn luyện nâng cao:
    - `--pitch_guidance` (mặc định: `True`): Sử dụng hướng dẫn cao độ.
    - `--g_pretrained_path` (mặc định: ``): Đường dẫn đến trọng số G đã huấn luyện trước.
    - `--d_pretrained_path` (mặc định: ``): Đường dẫn đến trọng số D đã huấn luyện trước.
    - `--vocoder` (mặc định: `Default`): Bộ mã hóa được sử dụng (`Default`, `MRF-HiFi-GAN`, `RefineGAN`).

6. Phát hiện huấn luyện quá mức:
    - `--overtraining_detector` (mặc định: `False`): Bật/tắt chế độ phát hiện huấn luyện quá mức.
    - `--overtraining_threshold` (mặc định: `50`): Ngưỡng để xác định huấn luyện quá mức.

7. Xử lý dữ liệu:
    - `--cleanup` (mặc định: `False`): Dọn dẹp tệp huấn luyện cũ để tiến hành huấn luyện lại từ đầu.

8. Tối ưu:
    - `--checkpointing` (mặc định: `False`): Bật/tắt checkpointing để tiết kiệm RAM.
    - `--deterministic` (mặc định: `False`): Khi bật sẽ sử dụng các thuật toán có tính xác định cao, đảm bảo rằng mỗi lần chạy cùng một dữ liệu đầu vào sẽ cho kết quả giống nhau.
    - `--benchmark` (mặc định: `False`): Khi bật sẽ thử nghiệm và chọn thuật toán tối ưu nhất cho phần cứng và kích thước cụ thể.
    """)
    quit()
elif argv_is_allows[18] in argv:
    print("""Sử dụng:
1. `--help_audio_effects`: Trợ giúp về phần thêm hiệu ứng âm thanh.
2. `--help_audioldm2`: Trợ giúp về phần chỉnh sửa nhạc.
3. `--help_convert`: Trợ giúp về chuyển đổi âm thanh.
4. `--help_create_dataset`: Trợ giúp về tạo dữ liệu huấn luyện.
5. `--help_create_index`: Trợ giúp về tạo chỉ mục.
6. `--help_extract`: Trợ giúp về trích xuất dữ liệu huấn luyện.
7. `--help_preprocess`: Trợ giúp về xử lý trước dữ liệu.
8. `--help_separator_music`: Trợ giúp về tách nhạc.
9. `--help_train`: Trợ giúp về huấn luyện mô hình.
    """)
    quit()


if __name__ == "__main__":
    if "--train" in argv:
        import torch.multiprocessing as mp
        mp.set_start_method("spawn")

    try:
        main()
    except:
        pass
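parser.py is a thin dispatcher: the first command-line token selects which inference module's main() gets imported, and that module then parses the remaining flags itself. A minimal usage sketch, assuming it is run from the repository root; the audio and model paths below are hypothetical placeholders, not files from this diff:

import subprocess

# Dispatch to main.inference.convert; the flags after --convert are consumed by that module's own parser.
subprocess.run([
    "python", "main/app/parser.py", "--convert",
    "--input_path", "./audios/input.wav",         # hypothetical input file
    "--output_path", "./audios/output.wav",
    "--pth_path", "./weights/example_model.pth",  # hypothetical model weights
    "--f0_method", "rmvpe",
], check=True)

# Print the corresponding help text shown above.
subprocess.run(["python", "main/app/parser.py", "--help_convert"], check=True)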
main/app/tensorboard.py
DELETED
@@ -1,30 +0,0 @@
import os
import sys
import json
import logging
import webbrowser

from tensorboard import program

sys.path.append(os.getcwd())

from main.configs.config import Config
translations = Config().translations

with open(os.path.join("main", "configs", "config.json"), "r") as f:
    configs = json.load(f)

def launch_tensorboard():
    for l in ["root", "tensorboard"]:
        logging.getLogger(l).setLevel(logging.ERROR)

    tb = program.TensorBoard()
    tb.configure(argv=[None, "--logdir", "assets/logs", f"--port={configs['tensorboard_port']}"])
    url = tb.launch()

    print(f"{translations['tensorboard_url']}: {url}")
    if "--open" in sys.argv: webbrowser.open(url)

    return f"{translations['tensorboard_url']}: {url}"

if __name__ == "__main__": launch_tensorboard()
main/configs/config.json
DELETED
@@ -1,547 +0,0 @@
{
    "language": "vi-VN",
    "support_language": ["en-US", "vi-VN"],
    "theme": "NoCrypt/miku",
    "themes": [
        "NoCrypt/miku", "gstaff/xkcd", "JohnSmith9982/small_and_pretty", "ParityError/Interstellar", "earneleh/paris",
        "shivi/calm_seafoam", "Hev832/Applio", "YTheme/Minecraft", "gstaff/sketch", "SebastianBravo/simci_css",
        "allenai/gradio-theme", "Nymbo/Nymbo_Theme_5", "lone17/kotaemon", "Zarkel/IBM_Carbon_Theme", "SherlockRamos/Feliz",
        "freddyaboulton/dracula_revamped", "freddyaboulton/bad-theme-space", "gradio/dracula_revamped", "abidlabs/dracula_revamped",
        "gradio/dracula_test", "gradio/seafoam", "gradio/glass", "gradio/monochrome", "gradio/soft", "gradio/default", "gradio/base",
        "abidlabs/pakistan", "dawood/microsoft_windows", "ysharma/steampunk", "ysharma/huggingface", "abidlabs/Lime",
        "freddyaboulton/this-theme-does-not-exist-2", "aliabid94/new-theme", "aliabid94/test2", "aliabid94/test3", "aliabid94/test4",
        "abidlabs/banana", "freddyaboulton/test-blue", "gstaff/whiteboard", "ysharma/llamas", "abidlabs/font-test",
        "YenLai/Superhuman", "bethecloud/storj_theme", "sudeepshouche/minimalist", "knotdgaf/gradiotest", "ParityError/Anime",
        "Ajaxon6255/Emerald_Isle", "ParityError/LimeFace", "finlaymacklon/smooth_slate", "finlaymacklon/boxy_violet",
        "derekzen/stardust", "EveryPizza/Cartoony-Gradio-Theme", "Ifeanyi/Cyanister", "Tshackelton/IBMPlex-DenseReadable",
        "snehilsanyal/scikit-learn", "Himhimhim/xkcd", "nota-ai/theme", "rawrsor1/Everforest", "rottenlittlecreature/Moon_Goblin",
        "abidlabs/test-yellow", "abidlabs/test-yellow3", "idspicQstitho/dracula_revamped", "kfahn/AnimalPose", "HaleyCH/HaleyCH_Theme",
        "simulKitke/dracula_test", "braintacles/CrimsonNight", "wentaohe/whiteboardv2", "reilnuud/polite", "remilia/Ghostly",
        "Franklisi/darkmode", "coding-alt/soft", "xiaobaiyuan/theme_land", "step-3-profit/Midnight-Deep", "xiaobaiyuan/theme_demo",
        "Taithrah/Minimal", "Insuz/SimpleIndigo", "zkunn/Alipay_Gradio_theme", "Insuz/Mocha", "xiaobaiyuan/theme_brief",
        "Ama434/434-base-Barlow", "Ama434/def_barlow", "Ama434/neutral-barlow", "dawood/dracula_test", "nuttea/Softblue",
        "BlueDancer/Alien_Diffusion", "naughtondale/monochrome", "Dagfinn1962/standard", "default"
    ],
    "mdx_model": [
        "Main_340", "Main_390", "Main_406", "Main_427", "Main_438", "Inst_full_292", "Inst_HQ_1", "Inst_HQ_2", "Inst_HQ_3",
        "Inst_HQ_4", "Inst_HQ_5", "Kim_Vocal_1", "Kim_Vocal_2", "Kim_Inst", "Inst_187_beta", "Inst_82_beta", "Inst_90_beta",
        "Voc_FT", "Crowd_HQ", "Inst_1", "Inst_2", "Inst_3", "MDXNET_1_9703", "MDXNET_2_9682", "MDXNET_3_9662", "Inst_Main",
        "MDXNET_Main", "MDXNET_9482"
    ],
    "demucs_model": ["HT-Normal", "HT-Tuned", "HD_MMI", "HT_6S"],
    "edge_tts": [
        "af-ZA-AdriNeural", "af-ZA-WillemNeural", "sq-AL-AnilaNeural", "sq-AL-IlirNeural", "am-ET-AmehaNeural", "am-ET-MekdesNeural",
        "ar-DZ-AminaNeural", "ar-DZ-IsmaelNeural", "ar-BH-AliNeural", "ar-BH-LailaNeural", "ar-EG-SalmaNeural", "ar-EG-ShakirNeural",
        "ar-IQ-BasselNeural", "ar-IQ-RanaNeural", "ar-JO-SanaNeural", "ar-JO-TaimNeural", "ar-KW-FahedNeural", "ar-KW-NouraNeural",
        "ar-LB-LaylaNeural", "ar-LB-RamiNeural", "ar-LY-ImanNeural", "ar-LY-OmarNeural", "ar-MA-JamalNeural", "ar-MA-MounaNeural",
        "ar-OM-AbdullahNeural", "ar-OM-AyshaNeural", "ar-QA-AmalNeural", "ar-QA-MoazNeural", "ar-SA-HamedNeural", "ar-SA-ZariyahNeural",
        "ar-SY-AmanyNeural", "ar-SY-LaithNeural", "ar-TN-HediNeural", "ar-TN-ReemNeural", "ar-AE-FatimaNeural", "ar-AE-HamdanNeural",
        "ar-YE-MaryamNeural", "ar-YE-SalehNeural", "az-AZ-BabekNeural", "az-AZ-BanuNeural", "bn-BD-NabanitaNeural", "bn-BD-PradeepNeural",
        "bn-IN-BashkarNeural", "bn-IN-TanishaaNeural", "bs-BA-GoranNeural", "bs-BA-VesnaNeural", "bg-BG-BorislavNeural", "bg-BG-KalinaNeural",
        "my-MM-NilarNeural", "my-MM-ThihaNeural", "ca-ES-EnricNeural", "ca-ES-JoanaNeural", "zh-HK-HiuGaaiNeural", "zh-HK-HiuMaanNeural",
        "zh-HK-WanLungNeural", "zh-CN-XiaoxiaoNeural", "zh-CN-XiaoyiNeural", "zh-CN-YunjianNeural", "zh-CN-YunxiNeural", "zh-CN-YunxiaNeural",
        "zh-CN-YunyangNeural", "zh-CN-liaoning-XiaobeiNeural", "zh-TW-HsiaoChenNeural", "zh-TW-YunJheNeural", "zh-TW-HsiaoYuNeural",
        "zh-CN-shaanxi-XiaoniNeural", "hr-HR-GabrijelaNeural", "hr-HR-SreckoNeural", "cs-CZ-AntoninNeural", "cs-CZ-VlastaNeural",
        "da-DK-ChristelNeural", "da-DK-JeppeNeural", "nl-BE-ArnaudNeural", "nl-BE-DenaNeural", "nl-NL-ColetteNeural", "nl-NL-FennaNeural",
        "nl-NL-MaartenNeural", "en-AU-NatashaNeural", "en-AU-WilliamNeural", "en-CA-ClaraNeural", "en-CA-LiamNeural", "en-HK-SamNeural",
        "en-HK-YanNeural", "en-IN-NeerjaExpressiveNeural", "en-IN-NeerjaNeural", "en-IN-PrabhatNeural", "en-IE-ConnorNeural", "en-IE-EmilyNeural",
        "en-KE-AsiliaNeural", "en-KE-ChilembaNeural", "en-NZ-MitchellNeural", "en-NZ-MollyNeural", "en-NG-AbeoNeural", "en-NG-EzinneNeural",
        "en-PH-JamesNeural", "en-PH-RosaNeural", "en-SG-LunaNeural", "en-SG-WayneNeural", "en-ZA-LeahNeural", "en-ZA-LukeNeural",
        "en-TZ-ElimuNeural", "en-TZ-ImaniNeural", "en-GB-LibbyNeural", "en-GB-MaisieNeural", "en-GB-RyanNeural", "en-GB-SoniaNeural",
        "en-GB-ThomasNeural", "en-US-AvaMultilingualNeural", "en-US-AndrewMultilingualNeural", "en-US-EmmaMultilingualNeural",
        "en-US-BrianMultilingualNeural", "en-US-AvaNeural", "en-US-AndrewNeural", "en-US-EmmaNeural", "en-US-BrianNeural", "en-US-AnaNeural",
        "en-US-AriaNeural", "en-US-ChristopherNeural", "en-US-EricNeural", "en-US-GuyNeural", "en-US-JennyNeural", "en-US-MichelleNeural",
        "en-US-RogerNeural", "en-US-SteffanNeural", "et-EE-AnuNeural", "et-EE-KertNeural", "fil-PH-AngeloNeural", "fil-PH-BlessicaNeural",
        "fi-FI-HarriNeural", "fi-FI-NooraNeural", "fr-BE-CharlineNeural", "fr-BE-GerardNeural", "fr-CA-ThierryNeural", "fr-CA-AntoineNeural",
        "fr-CA-JeanNeural", "fr-CA-SylvieNeural", "fr-FR-VivienneMultilingualNeural", "fr-FR-RemyMultilingualNeural", "fr-FR-DeniseNeural",
        "fr-FR-EloiseNeural", "fr-FR-HenriNeural", "fr-CH-ArianeNeural", "fr-CH-FabriceNeural", "gl-ES-RoiNeural", "gl-ES-SabelaNeural",
        "ka-GE-EkaNeural", "ka-GE-GiorgiNeural", "de-AT-IngridNeural", "de-AT-JonasNeural", "de-DE-SeraphinaMultilingualNeural",
        "de-DE-FlorianMultilingualNeural", "de-DE-AmalaNeural", "de-DE-ConradNeural", "de-DE-KatjaNeural", "de-DE-KillianNeural",
        "de-CH-JanNeural", "de-CH-LeniNeural", "el-GR-AthinaNeural", "el-GR-NestorasNeural", "gu-IN-DhwaniNeural", "gu-IN-NiranjanNeural",
        "he-IL-AvriNeural", "he-IL-HilaNeural", "hi-IN-MadhurNeural", "hi-IN-SwaraNeural", "hu-HU-NoemiNeural", "hu-HU-TamasNeural",
        "is-IS-GudrunNeural", "is-IS-GunnarNeural", "id-ID-ArdiNeural", "id-ID-GadisNeural", "ga-IE-ColmNeural", "ga-IE-OrlaNeural",
        "it-IT-GiuseppeNeural", "it-IT-DiegoNeural", "it-IT-ElsaNeural", "it-IT-IsabellaNeural", "ja-JP-KeitaNeural", "ja-JP-NanamiNeural",
        "jv-ID-DimasNeural", "jv-ID-SitiNeural", "kn-IN-GaganNeural", "kn-IN-SapnaNeural", "kk-KZ-AigulNeural", "kk-KZ-DauletNeural",
        "km-KH-PisethNeural", "km-KH-SreymomNeural", "ko-KR-HyunsuNeural", "ko-KR-InJoonNeural", "ko-KR-SunHiNeural", "lo-LA-ChanthavongNeural",
        "lo-LA-KeomanyNeural", "lv-LV-EveritaNeural", "lv-LV-NilsNeural", "lt-LT-LeonasNeural", "lt-LT-OnaNeural", "mk-MK-AleksandarNeural",
        "mk-MK-MarijaNeural", "ms-MY-OsmanNeural", "ms-MY-YasminNeural", "ml-IN-MidhunNeural", "ml-IN-SobhanaNeural", "mt-MT-GraceNeural",
        "mt-MT-JosephNeural", "mr-IN-AarohiNeural", "mr-IN-ManoharNeural", "mn-MN-BataaNeural", "mn-MN-YesuiNeural", "ne-NP-HemkalaNeural",
        "ne-NP-SagarNeural", "nb-NO-FinnNeural", "nb-NO-PernilleNeural", "ps-AF-GulNawazNeural", "ps-AF-LatifaNeural", "fa-IR-DilaraNeural",
        "fa-IR-FaridNeural", "pl-PL-MarekNeural", "pl-PL-ZofiaNeural", "pt-BR-ThalitaNeural", "pt-BR-AntonioNeural", "pt-BR-FranciscaNeural",
        "pt-PT-DuarteNeural", "pt-PT-RaquelNeural", "ro-RO-AlinaNeural", "ro-RO-EmilNeural", "ru-RU-DmitryNeural", "ru-RU-SvetlanaNeural",
        "sr-RS-NicholasNeural", "sr-RS-SophieNeural", "si-LK-SameeraNeural", "si-LK-ThiliniNeural", "sk-SK-LukasNeural", "sk-SK-ViktoriaNeural",
        "sl-SI-PetraNeural", "sl-SI-RokNeural", "so-SO-MuuseNeural", "so-SO-UbaxNeural", "es-AR-ElenaNeural", "es-AR-TomasNeural",
        "es-BO-MarceloNeural", "es-BO-SofiaNeural", "es-CL-CatalinaNeural", "es-CL-LorenzoNeural", "es-ES-XimenaNeural", "es-CO-GonzaloNeural",
        "es-CO-SalomeNeural", "es-CR-JuanNeural", "es-CR-MariaNeural", "es-CU-BelkysNeural", "es-CU-ManuelNeural", "es-DO-EmilioNeural",
        "es-DO-RamonaNeural", "es-EC-AndreaNeural", "es-EC-LuisNeural", "es-SV-LorenaNeural", "es-SV-RodrigoNeural", "es-GQ-JavierNeural",
        "es-GQ-TeresaNeural", "es-GT-AndresNeural", "es-GT-MartaNeural", "es-HN-CarlosNeural", "es-HN-KarlaNeural", "es-MX-DaliaNeural",
        "es-MX-JorgeNeural", "es-NI-FedericoNeural", "es-NI-YolandaNeural", "es-PA-MargaritaNeural", "es-PA-RobertoNeural", "es-PY-MarioNeural",
        "es-PY-TaniaNeural", "es-PE-AlexNeural", "es-PE-CamilaNeural", "es-PR-KarinaNeural", "es-PR-VictorNeural", "es-ES-AlvaroNeural",
        "es-ES-ElviraNeural", "es-US-AlonsoNeural", "es-US-PalomaNeural", "es-UY-MateoNeural", "es-UY-ValentinaNeural", "es-VE-PaolaNeural",
        "es-VE-SebastianNeural", "su-ID-JajangNeural", "su-ID-TutiNeural", "sw-KE-RafikiNeural", "sw-KE-ZuriNeural", "sw-TZ-DaudiNeural",
        "sw-TZ-RehemaNeural", "sv-SE-MattiasNeural", "sv-SE-SofieNeural", "ta-IN-PallaviNeural", "ta-IN-ValluvarNeural", "ta-MY-KaniNeural",
        "ta-MY-SuryaNeural", "ta-SG-AnbuNeural", "ta-SG-VenbaNeural", "ta-LK-KumarNeural", "ta-LK-SaranyaNeural", "te-IN-MohanNeural",
        "te-IN-ShrutiNeural", "th-TH-NiwatNeural", "th-TH-PremwadeeNeural", "tr-TR-AhmetNeural", "tr-TR-EmelNeural", "uk-UA-OstapNeural",
        "uk-UA-PolinaNeural", "ur-IN-GulNeural", "ur-IN-SalmanNeural", "ur-PK-AsadNeural", "ur-PK-UzmaNeural", "uz-UZ-MadinaNeural",
        "uz-UZ-SardorNeural", "vi-VN-HoaiMyNeural", "vi-VN-NamMinhNeural", "cy-GB-AledNeural", "cy-GB-NiaNeural", "zu-ZA-ThandoNeural",
        "zu-ZA-ThembaNeural"
    ],
    "google_tts_voice": [
        "af", "am", "ar", "bg", "bn", "bs", "ca", "cs", "cy", "da", "de", "el", "en", "es", "et", "eu", "fi", "fr", "fr-CA",
        "gl", "gu", "ha", "hi", "hr", "hu", "id", "is", "it", "iw", "ja", "jw", "km", "kn", "ko", "la", "lt", "lv", "ml", "mr",
        "ms", "my", "ne", "nl", "no", "pa", "pl", "pt", "pt-PT", "ro", "ru", "si", "sk", "sq", "sr", "su", "sv", "sw", "ta",
        "te", "th", "tl", "tr", "uk", "ur", "vi", "yue", "zh-CN", "zh-TW", "zh"
    ],
    "fp16": true,
    "separator_tab": true,
    "convert_tab": true,
    "convert_with_whisper": true,
    "tts_tab": true,
    "audioldm2": true,
    "effects_tab": true,
    "create_dataset_tab": true,
    "training_tab": true,
    "fushion_tab": true,
    "read_tab": true,
    "onnx_tab": true,
    "downloads_tab": true,
    "f0_extractor_tab": true,
    "settings_tab": true,
    "report_bug_tab": true,
    "font": "https://fonts.googleapis.com/css2?family=Shadows+Into+Light&display=swap",
    "app_port": 7860,
    "tensorboard_port": 6870,
    "num_of_restart": 5,
    "server_name": "0.0.0.0",
    "app_show_error": true
}
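main/app/app.py is not rendered in this view, but the keys at the bottom of config.json map naturally onto a Gradio launch call (the themes list above is made of Gradio theme identifiers). A minimal sketch of how those settings would typically be consumed, assuming a Gradio UI; the Blocks body here is a placeholder, not the real interface:

import json
import gradio as gr

with open("main/configs/config.json", "r") as f:
    configs = json.load(f)

with gr.Blocks(theme=configs["theme"]) as app:  # e.g. "NoCrypt/miku"
    gr.Markdown("placeholder UI")

app.launch(
    server_name=configs["server_name"],    # "0.0.0.0"
    server_port=configs["app_port"],       # 7860
    show_error=configs["app_show_error"],  # true
)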
main/configs/config.py
DELETED
@@ -1,90 +0,0 @@
import os
import json
import torch


version_config_paths = [os.path.join(version, size) for version in ["v1", "v2"] for size in ["32000.json", "40000.json", "48000.json"]]

def singleton(cls):
    instances = {}

    def get_instance(*args, **kwargs):
        if cls not in instances: instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance

@singleton
class Config:
    def __init__(self):
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.configs = json.load(open(os.path.join("main", "configs", "config.json"), "r"))
        self.translations = self.multi_language()
        self.json_config = self.load_config_json()
        self.gpu_mem = None
        self.per_preprocess = 3.7
        self.is_half = self.is_fp16()
        self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()

    def multi_language(self):
        try:
            lang = self.configs.get("language", "vi-VN")
            if len([l for l in os.listdir(os.path.join("assets", "languages")) if l.endswith(".json")]) < 1: raise FileNotFoundError("Không tìm thấy bất cứ gói ngôn ngữ nào(No package languages found)")

            if not lang: lang = "vi-VN"
            if lang not in self.configs["support_language"]: raise ValueError("Ngôn ngữ không được hỗ trợ(Language not supported)")

            lang_path = os.path.join("assets", "languages", f"{lang}.json")
            if not os.path.exists(lang_path): lang_path = os.path.join("assets", "languages", "vi-VN.json")

            with open(lang_path, encoding="utf-8") as f:
                translations = json.load(f)
        except json.JSONDecodeError:
            print(self.translations["empty_json"].format(file=lang))
            pass

        return translations

    def is_fp16(self):
        fp16 = self.configs.get("fp16", False)

        if self.device in ["cpu", "mps"] and fp16:
            self.configs["fp16"] = False
            fp16 = False

            with open(os.path.join("main", "configs", "config.json"), "w") as f:
                json.dump(self.configs, f, indent=4)

        if not fp16: self.preprocess_per = 3.0
        return fp16

    def load_config_json(self):
        configs = {}

        for config_file in version_config_paths:
            try:
                with open(os.path.join("main", "configs", config_file), "r") as f:
                    configs[config_file] = json.load(f)
            except json.JSONDecodeError:
                print(self.translations["empty_json"].format(file=config_file))
                pass

        return configs

    def device_config(self):
        if self.device.startswith("cuda"): self.set_cuda_config()
        elif self.has_mps(): self.device = "mps"
        else: self.device = "cpu"

        if self.gpu_mem is not None and self.gpu_mem <= 4:
            self.preprocess_per = 3.0
            return 1, 5, 30, 32

        return (3, 10, 60, 65) if self.is_half else (1, 6, 38, 41)

    def set_cuda_config(self):
        i_device = int(self.device.split(":")[-1])
        self.gpu_mem = torch.cuda.get_device_properties(i_device).total_memory // (1024**3)

    def has_mps(self):
        return torch.backends.mps.is_available()
main/configs/decrypt.bin
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:330268cbf6b9317a76510b533e1640ef48ed074a07c013e5b1abc4d48cfd9dce
size 32
main/configs/v1/32000.json
DELETED
@@ -1,46 +0,0 @@
{
    "train": {
        "log_interval": 200, "seed": 1234, "epochs": 20000, "learning_rate": 0.0001, "betas": [0.8, 0.99], "eps": 1e-09,
        "batch_size": 4, "lr_decay": 0.999875, "segment_size": 12800, "init_lr_ratio": 1, "warmup_epochs": 0, "c_mel": 45, "c_kl": 1.0
    },
    "data": {
        "max_wav_value": 32768.0, "sample_rate": 32000, "filter_length": 1024, "hop_length": 320, "win_length": 1024,
        "n_mel_channels": 80, "mel_fmin": 0.0, "mel_fmax": null
    },
    "model": {
        "inter_channels": 192, "hidden_channels": 192, "filter_channels": 768, "text_enc_hidden_dim": 256, "n_heads": 2,
        "n_layers": 6, "kernel_size": 3, "p_dropout": 0, "resblock": "1", "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], "upsample_rates": [10, 4, 2, 2, 2],
        "upsample_initial_channel": 512, "upsample_kernel_sizes": [16, 16, 4, 4, 4], "use_spectral_norm": false,
        "gin_channels": 256, "spk_embed_dim": 109
    }
}
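One invariant these training configs share: the product of model.upsample_rates must equal data.hop_length, so that one mel-spectrogram frame is upsampled to exactly one hop of waveform samples (10*4*2*2*2 = 320 here for the 32 kHz config). A small sketch that checks this against the files in this directory:

import json
import math

for path in ["main/configs/v1/32000.json", "main/configs/v1/40000.json", "main/configs/v1/48000.json"]:
    with open(path, "r") as f:
        cfg = json.load(f)
    # e.g. v1/40000.json: 10*10*2*2 = 400, matching its hop_length.
    assert math.prod(cfg["model"]["upsample_rates"]) == cfg["data"]["hop_length"]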
main/configs/v1/40000.json
DELETED
@@ -1,46 +0,0 @@
{
    "train": {
        "log_interval": 200, "seed": 1234, "epochs": 20000, "learning_rate": 0.0001, "betas": [0.8, 0.99], "eps": 1e-09,
        "batch_size": 4, "lr_decay": 0.999875, "segment_size": 12800, "init_lr_ratio": 1, "warmup_epochs": 0, "c_mel": 45, "c_kl": 1.0
    },
    "data": {
        "max_wav_value": 32768.0, "sample_rate": 40000, "filter_length": 2048, "hop_length": 400, "win_length": 2048,
        "n_mel_channels": 125, "mel_fmin": 0.0, "mel_fmax": null
    },
    "model": {
        "inter_channels": 192, "hidden_channels": 192, "filter_channels": 768, "text_enc_hidden_dim": 256, "n_heads": 2,
        "n_layers": 6, "kernel_size": 3, "p_dropout": 0, "resblock": "1", "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], "upsample_rates": [10, 10, 2, 2],
        "upsample_initial_channel": 512, "upsample_kernel_sizes": [16, 16, 4, 4], "use_spectral_norm": false,
        "gin_channels": 256, "spk_embed_dim": 109
    }
}
main/configs/v1/48000.json
DELETED
@@ -1,46 +0,0 @@
{
    "train": {
        "log_interval": 200, "seed": 1234, "epochs": 20000, "learning_rate": 0.0001, "betas": [0.8, 0.99], "eps": 1e-09,
        "batch_size": 4, "lr_decay": 0.999875, "segment_size": 11520, "init_lr_ratio": 1, "warmup_epochs": 0, "c_mel": 45, "c_kl": 1.0
    },
    "data": {
        "max_wav_value": 32768.0, "sample_rate": 48000, "filter_length": 2048, "hop_length": 480, "win_length": 2048,
        "n_mel_channels": 128, "mel_fmin": 0.0, "mel_fmax": null
    },
    "model": {
        "inter_channels": 192, "hidden_channels": 192, "filter_channels": 768, "text_enc_hidden_dim": 256, "n_heads": 2,
        "n_layers": 6, "kernel_size": 3, "p_dropout": 0, "resblock": "1", "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], "upsample_rates": [10, 6, 2, 2, 2],
        "upsample_initial_channel": 512, "upsample_kernel_sizes": [16, 16, 4, 4, 4], "use_spectral_norm": false,
        "gin_channels": 256, "spk_embed_dim": 109
    }
}
main/configs/v2/32000.json
DELETED
@@ -1,42 +0,0 @@
{
    "train": {
        "log_interval": 200, "seed": 1234, "learning_rate": 0.0001, "betas": [0.8, 0.99], "eps": 1e-09,
        "lr_decay": 0.999875, "segment_size": 12800, "c_mel": 45, "c_kl": 1.0
    },
    "data": {
        "max_wav_value": 32768.0, "sample_rate": 32000, "filter_length": 1024, "hop_length": 320, "win_length": 1024,
        "n_mel_channels": 80, "mel_fmin": 0.0, "mel_fmax": null
    },
    "model": {
        "inter_channels": 192, "hidden_channels": 192, "filter_channels": 768, "text_enc_hidden_dim": 768, "n_heads": 2,
        "n_layers": 6, "kernel_size": 3, "p_dropout": 0, "resblock": "1", "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], "upsample_rates": [10, 8, 2, 2],
        "upsample_initial_channel": 512, "upsample_kernel_sizes": [20, 16, 4, 4], "use_spectral_norm": false,
        "gin_channels": 256, "spk_embed_dim": 109
    }
}
main/configs/v2/40000.json
DELETED
@@ -1,42 +0,0 @@
{
    "train": {
        "log_interval": 200, "seed": 1234, "learning_rate": 0.0001, "betas": [0.8, 0.99], "eps": 1e-09,
        "lr_decay": 0.999875, "segment_size": 12800, "c_mel": 45, "c_kl": 1.0
    },
    "data": {
        "max_wav_value": 32768.0, "sample_rate": 40000, "filter_length": 2048, "hop_length": 400, "win_length": 2048,
        "n_mel_channels": 125, "mel_fmin": 0.0, "mel_fmax": null
    },
    "model": {
        "inter_channels": 192, "hidden_channels": 192, "filter_channels": 768, "text_enc_hidden_dim": 768, "n_heads": 2,
        "n_layers": 6, "kernel_size": 3, "p_dropout": 0, "resblock": "1", "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], "upsample_rates": [10, 10, 2, 2],
        "upsample_initial_channel": 512, "upsample_kernel_sizes": [16, 16, 4, 4], "use_spectral_norm": false,
        "gin_channels": 256, "spk_embed_dim": 109
    }
}
main/configs/v2/48000.json
DELETED
@@ -1,42 +0,0 @@
{
    "train": {
        "log_interval": 200, "seed": 1234, "learning_rate": 0.0001, "betas": [0.8, 0.99], "eps": 1e-09,
        "lr_decay": 0.999875, "segment_size": 17280, "c_mel": 45, "c_kl": 1.0
    },
    "data": {
        "max_wav_value": 32768.0, "sample_rate": 48000, "filter_length": 2048, "hop_length": 480, "win_length": 2048,
        "n_mel_channels": 128, "mel_fmin": 0.0, "mel_fmax": null
    },
    "model": {
        "inter_channels": 192, "hidden_channels": 192, "filter_channels": 768, "text_enc_hidden_dim": 768, "n_heads": 2,
        "n_layers": 6, "kernel_size": 3, "p_dropout": 0, "resblock": "1", "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], "upsample_rates": [12, 10, 2, 2],
        "upsample_initial_channel": 512, "upsample_kernel_sizes": [24, 20, 4, 4], "use_spectral_norm": false,
        "gin_channels": 256, "spk_embed_dim": 109
    }
}
main/inference/audio_effects.py
DELETED
@@ -1,180 +0,0 @@
-import os
-import sys
-import librosa
-import argparse
-
-import numpy as np
-import soundfile as sf
-
-from distutils.util import strtobool
-from scipy.signal import butter, filtfilt
-from pedalboard import Pedalboard, Chorus, Distortion, Reverb, PitchShift, Delay, Limiter, Gain, Bitcrush, Clipping, Compressor, Phaser, HighpassFilter
-
-sys.path.append(os.getcwd())
-
-from main.configs.config import Config
-from main.library.utils import pydub_convert, pydub_load
-
-translations = Config().translations
-
-def parse_arguments():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--input_path", type=str, required=True)
-    parser.add_argument("--output_path", type=str, default="./audios/apply_effects.wav")
-    parser.add_argument("--export_format", type=str, default="wav")
-    parser.add_argument("--resample", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--resample_sr", type=int, default=0)
-    parser.add_argument("--chorus", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--chorus_depth", type=float, default=0.5)
-    parser.add_argument("--chorus_rate", type=float, default=1.5)
-    parser.add_argument("--chorus_mix", type=float, default=0.5)
-    parser.add_argument("--chorus_delay", type=int, default=10)
-    parser.add_argument("--chorus_feedback", type=float, default=0)
-    parser.add_argument("--distortion", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--drive_db", type=int, default=20)
-    parser.add_argument("--reverb", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--reverb_room_size", type=float, default=0.5)
-    parser.add_argument("--reverb_damping", type=float, default=0.5)
-    parser.add_argument("--reverb_wet_level", type=float, default=0.33)
-    parser.add_argument("--reverb_dry_level", type=float, default=0.67)
-    parser.add_argument("--reverb_width", type=float, default=1)
-    parser.add_argument("--reverb_freeze_mode", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--pitchshift", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--pitch_shift", type=int, default=0)
-    parser.add_argument("--delay", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--delay_seconds", type=float, default=0.5)
-    parser.add_argument("--delay_feedback", type=float, default=0.5)
-    parser.add_argument("--delay_mix", type=float, default=0.5)
-    parser.add_argument("--compressor", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--compressor_threshold", type=int, default=-20)
-    parser.add_argument("--compressor_ratio", type=float, default=4)
-    parser.add_argument("--compressor_attack_ms", type=float, default=10)
-    parser.add_argument("--compressor_release_ms", type=int, default=200)
-    parser.add_argument("--limiter", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--limiter_threshold", type=int, default=0)
-    parser.add_argument("--limiter_release", type=int, default=100)
-    parser.add_argument("--gain", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--gain_db", type=int, default=0)
-    parser.add_argument("--bitcrush", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--bitcrush_bit_depth", type=int, default=16)
-    parser.add_argument("--clipping", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--clipping_threshold", type=int, default=-10)
-    parser.add_argument("--phaser", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--phaser_rate_hz", type=float, default=0.5)
-    parser.add_argument("--phaser_depth", type=float, default=0.5)
-    parser.add_argument("--phaser_centre_frequency_hz", type=int, default=1000)
-    parser.add_argument("--phaser_feedback", type=float, default=0)
-    parser.add_argument("--phaser_mix", type=float, default=0.5)
-    parser.add_argument("--treble_bass_boost", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--bass_boost_db", type=int, default=0)
-    parser.add_argument("--bass_boost_frequency", type=int, default=100)
-    parser.add_argument("--treble_boost_db", type=int, default=0)
-    parser.add_argument("--treble_boost_frequency", type=int, default=3000)
-    parser.add_argument("--fade_in_out", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--fade_in_duration", type=float, default=2000)
-    parser.add_argument("--fade_out_duration", type=float, default=2000)
-    parser.add_argument("--audio_combination", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--audio_combination_input", type=str)
-
-    return parser.parse_args()
-
-def process_audio(input_path, output_path, resample, resample_sr, chorus_depth, chorus_rate, chorus_mix, chorus_delay, chorus_feedback, distortion_drive, reverb_room_size, reverb_damping, reverb_wet_level, reverb_dry_level, reverb_width, reverb_freeze_mode, pitch_shift, delay_seconds, delay_feedback, delay_mix, compressor_threshold, compressor_ratio, compressor_attack_ms, compressor_release_ms, limiter_threshold, limiter_release, gain_db, bitcrush_bit_depth, clipping_threshold, phaser_rate_hz, phaser_depth, phaser_centre_frequency_hz, phaser_feedback, phaser_mix, bass_boost_db, bass_boost_frequency, treble_boost_db, treble_boost_frequency, fade_in_duration, fade_out_duration, export_format, chorus, distortion, reverb, pitchshift, delay, compressor, limiter, gain, bitcrush, clipping, phaser, treble_bass_boost, fade_in_out, audio_combination, audio_combination_input):
-    def bass_boost(audio, gain_db, frequency, sample_rate):
-        if gain_db >= 1:
-            b, a = butter(4, frequency / (0.5 * sample_rate), btype='low')
-
-            return filtfilt(b, a, audio) * 10 ** (gain_db / 20)
-        else: return audio
-
-    def treble_boost(audio, gain_db, frequency, sample_rate):
-        if gain_db >=1:
-            b, a = butter(4, frequency / (0.5 * sample_rate), btype='high')
-
-            return filtfilt(b, a, audio) * 10 ** (gain_db / 20)
-        else: return audio
-
-    def fade_out_effect(audio, sr, duration=3.0):
-        length = int(duration * sr)
-        end = audio.shape[0]
-
-        if length > end: length = end
-        start = end - length
-
-        audio[start:end] = audio[start:end] * np.linspace(1.0, 0.0, length)
-        return audio
-
-    def fade_in_effect(audio, sr, duration=3.0):
-        length = int(duration * sr)
-        start = 0
-
-        if length > audio.shape[0]: length = audio.shape[0]
-        end = length
-
-        audio[start:end] = audio[start:end] * np.linspace(0.0, 1.0, length)
-        return audio
-
-    if not input_path or not os.path.exists(input_path):
-        print(translations["input_not_valid"])
-        sys.exit(1)
-
-    if not output_path:
-        print(translations["output_not_valid"])
-        sys.exit(1)
-
-    if os.path.exists(output_path): os.remove(output_path)
-
-    try:
-        input_path = input_path.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
-
-        try:
-            audio, sample_rate = sf.read(input_path, dtype=np.float32)
-        except:
-            audio, sample_rate = librosa.load(input_path, sr=None)
-    except Exception as e:
-        raise RuntimeError(f"{translations['errors_loading_audio']}: {e}")
-
-    audio = audio.flatten()
-
-    try:
-        board = Pedalboard([HighpassFilter()])
-
-        if chorus: board.append(Chorus(depth=chorus_depth, rate_hz=chorus_rate, mix=chorus_mix, centre_delay_ms=chorus_delay, feedback=chorus_feedback))
-        if distortion: board.append(Distortion(drive_db=distortion_drive))
-        if reverb: board.append(Reverb(room_size=reverb_room_size, damping=reverb_damping, wet_level=reverb_wet_level, dry_level=reverb_dry_level, width=reverb_width, freeze_mode=1 if reverb_freeze_mode else 0))
-        if pitchshift: board.append(PitchShift(semitones=pitch_shift))
-        if delay: board.append(Delay(delay_seconds=delay_seconds, feedback=delay_feedback, mix=delay_mix))
-        if compressor: board.append(Compressor(threshold_db=compressor_threshold, ratio=compressor_ratio, attack_ms=compressor_attack_ms, release_ms=compressor_release_ms))
-        if limiter: board.append(Limiter(threshold_db=limiter_threshold, release_ms=limiter_release))
-        if gain: board.append(Gain(gain_db=gain_db))
-        if bitcrush: board.append(Bitcrush(bit_depth=bitcrush_bit_depth))
-        if clipping: board.append(Clipping(threshold_db=clipping_threshold))
-        if phaser: board.append(Phaser(rate_hz=phaser_rate_hz, depth=phaser_depth, centre_frequency_hz=phaser_centre_frequency_hz, feedback=phaser_feedback, mix=phaser_mix))
-
-        processed_audio = board(audio, sample_rate)
-
-        if treble_bass_boost:
-            processed_audio = bass_boost(processed_audio, bass_boost_db, bass_boost_frequency, sample_rate)
-            processed_audio = treble_boost(processed_audio, treble_boost_db, treble_boost_frequency, sample_rate)
-
-        if fade_in_out:
-            processed_audio = fade_in_effect(processed_audio, sample_rate, fade_in_duration)
-            processed_audio = fade_out_effect(processed_audio, sample_rate, fade_out_duration)
-
-        if resample_sr != sample_rate and resample_sr > 0 and resample:
-            target_sr = min([8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100, 48000, 96000], key=lambda x: abs(x - resample_sr))
-            processed_audio = librosa.resample(processed_audio, orig_sr=sample_rate, target_sr=target_sr, res_type="soxr_vhq")
-            sample_rate = target_sr
-
-        sf.write(output_path.replace("wav", export_format), processed_audio, sample_rate, format=export_format)
-
-        if audio_combination: pydub_convert(pydub_load(audio_combination_input)).overlay(pydub_convert(pydub_load(output_path.replace("wav", export_format)))).export(output_path.replace("wav", export_format), format=export_format)
-    except Exception as e:
-        raise RuntimeError(translations["apply_error"].format(e=e))
-
-    return output_path
-
-def main():
-    args = parse_arguments()
-    process_audio(input_path=args.input_path, output_path=args.output_path, resample=args.resample, resample_sr=args.resample_sr, chorus_depth=args.chorus_depth, chorus_rate=args.chorus_rate, chorus_mix=args.chorus_mix, chorus_delay=args.chorus_delay, chorus_feedback=args.chorus_feedback, distortion_drive=args.drive_db, reverb_room_size=args.reverb_room_size, reverb_damping=args.reverb_damping, reverb_wet_level=args.reverb_wet_level, reverb_dry_level=args.reverb_dry_level, reverb_width=args.reverb_width, reverb_freeze_mode=args.reverb_freeze_mode, pitch_shift=args.pitch_shift, delay_seconds=args.delay_seconds, delay_feedback=args.delay_feedback, delay_mix=args.delay_mix, compressor_threshold=args.compressor_threshold, compressor_ratio=args.compressor_ratio, compressor_attack_ms=args.compressor_attack_ms, compressor_release_ms=args.compressor_release_ms, limiter_threshold=args.limiter_threshold, limiter_release=args.limiter_release, gain_db=args.gain_db, bitcrush_bit_depth=args.bitcrush_bit_depth, clipping_threshold=args.clipping_threshold, phaser_rate_hz=args.phaser_rate_hz, phaser_depth=args.phaser_depth, phaser_centre_frequency_hz=args.phaser_centre_frequency_hz, phaser_feedback=args.phaser_feedback, phaser_mix=args.phaser_mix, bass_boost_db=args.bass_boost_db, bass_boost_frequency=args.bass_boost_frequency, treble_boost_db=args.treble_boost_db, treble_boost_frequency=args.treble_boost_frequency, fade_in_duration=args.fade_in_duration, fade_out_duration=args.fade_out_duration, export_format=args.export_format, chorus=args.chorus, distortion=args.distortion, reverb=args.reverb, pitchshift=args.pitchshift, delay=args.delay, compressor=args.compressor, limiter=args.limiter, gain=args.gain, bitcrush=args.bitcrush, clipping=args.clipping, phaser=args.phaser, treble_bass_boost=args.treble_bass_boost, fade_in_out=args.fade_in_out, audio_combination=args.audio_combination, audio_combination_input=args.audio_combination_input)

-if __name__ == "__main__": main()
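audio_effects.py was a flag-per-effect CLI built around pedalboard: it always started the chain with a HighpassFilter, appended one processor per enabled flag, and ran the whole board over the flattened float32 samples before optional resampling and overlay. A minimal sketch of that core pattern, assuming hypothetical input/output paths and example effect settings (the real script derived all of these from its --input_path, --output_path and per-effect arguments):

import soundfile as sf
from pedalboard import Pedalboard, Gain, HighpassFilter, Reverb

# Hypothetical file paths; the deleted script read them from its CLI flags.
audio, sample_rate = sf.read("./audios/input.wav", dtype="float32")

board = Pedalboard([HighpassFilter()])  # the script's fixed first stage
board.append(Reverb(room_size=0.5, damping=0.5, wet_level=0.33, dry_level=0.67, width=1))  # its --reverb defaults
board.append(Gain(gain_db=3))  # arbitrary example gain

processed = board(audio.flatten(), sample_rate)
sf.write("./audios/apply_effects.wav", processed, sample_rate)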
main/inference/audioldm2.py
DELETED
@@ -1,210 +0,0 @@
-import os
-import sys
-import time
-import tqdm
-import torch
-import logging
-import librosa
-import argparse
-import scipy.signal
-import logging.handlers
-
-import numpy as np
-import soundfile as sf
-
-from torch import inference_mode
-from distutils.util import strtobool
-
-sys.path.append(os.getcwd())
-
-from main.configs.config import Config
-from main.library.audioldm2.utils import load_audio
-from main.library.audioldm2.models import load_model
-
-config = Config()
-translations = config.translations
-logger = logging.getLogger(__name__)
-logger.propagate = False
-
-for l in ["torch", "httpx", "httpcore", "diffusers", "transformers"]:
-    logging.getLogger(l).setLevel(logging.ERROR)
-
-if logger.hasHandlers(): logger.handlers.clear()
-else:
-    console_handler = logging.StreamHandler()
-    console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
-    console_handler.setFormatter(console_formatter)
-    console_handler.setLevel(logging.INFO)
-    file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "audioldm2.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
-    file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
-    file_handler.setFormatter(file_formatter)
-    file_handler.setLevel(logging.DEBUG)
-    logger.addHandler(console_handler)
-    logger.addHandler(file_handler)
-    logger.setLevel(logging.DEBUG)
-
-def parse_arguments():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--input_path", type=str, required=True)
-    parser.add_argument("--output_path", type=str, default="./output.wav")
-    parser.add_argument("--export_format", type=str, default="wav")
-    parser.add_argument("--sample_rate", type=int, default=44100)
-    parser.add_argument("--audioldm_model", type=str, default="audioldm2-music")
-    parser.add_argument("--source_prompt", type=str, default="")
-    parser.add_argument("--target_prompt", type=str, default="")
-    parser.add_argument("--steps", type=int, default=200)
-    parser.add_argument("--cfg_scale_src", type=float, default=3.5)
-    parser.add_argument("--cfg_scale_tar", type=float, default=12)
-    parser.add_argument("--t_start", type=int, default=45)
-    parser.add_argument("--save_compute", type=lambda x: bool(strtobool(x)), default=False)
-
-    return parser.parse_args()
-
-def main():
-    args = parse_arguments()
-    input_path, output_path, export_format, sample_rate, audioldm_model, source_prompt, target_prompt, steps, cfg_scale_src, cfg_scale_tar, t_start, save_compute = args.input_path, args.output_path, args.export_format, args.sample_rate, args.audioldm_model, args.source_prompt, args.target_prompt, args.steps, args.cfg_scale_src, args.cfg_scale_tar, args.t_start, args.save_compute
-
-    log_data = {translations['audio_path']: input_path, translations['output_path']: output_path.replace('wav', export_format), translations['model_name']: audioldm_model, translations['export_format']: export_format, translations['sample_rate']: sample_rate, translations['steps']: steps, translations['source_prompt']: source_prompt, translations['target_prompt']: target_prompt, translations['cfg_scale_src']: cfg_scale_src, translations['cfg_scale_tar']: cfg_scale_tar, translations['t_start']: t_start, translations['save_compute']: save_compute}
-
-    for key, value in log_data.items():
-        logger.debug(f"{key}: {value}")
-
-    start_time = time.time()
-    logger.info(translations["start_edit"].format(input_path=input_path))
-    pid_path = os.path.join("assets", "audioldm2_pid.txt")
-    with open(pid_path, "w") as pid_file:
-        pid_file.write(str(os.getpid()))
-
-    try:
-        edit(input_path, output_path, audioldm_model, source_prompt, target_prompt, steps, cfg_scale_src, cfg_scale_tar, t_start, save_compute, sample_rate, config.device, export_format=export_format)
-    except Exception as e:
-        logger.error(translations["error_edit"].format(e=e))
-        import traceback
-        logger.debug(traceback.format_exc())
-
-    logger.info(translations["edit_success"].format(time=f"{(time.time() - start_time):.2f}", output_path=output_path.replace('wav', export_format)))
-
-def invert(ldm_stable, x0, prompt_src, num_diffusion_steps, cfg_scale_src, duration, save_compute):
-    with inference_mode():
-        w0 = ldm_stable.vae_encode(x0)
-
-    _, zs, wts, extra_info = inversion_forward_process(ldm_stable, w0, etas=1, prompts=[prompt_src], cfg_scales=[cfg_scale_src], num_inference_steps=num_diffusion_steps, numerical_fix=True, duration=duration, save_compute=save_compute)
-    return zs, wts, extra_info
-
-def low_pass_filter(audio, cutoff=7500, sr=16000):
-    b, a = scipy.signal.butter(4, cutoff / (sr / 2), btype='low')
-    return scipy.signal.filtfilt(b, a, audio)
-
-def sample(output_audio, sr, ldm_stable, zs, wts, extra_info, prompt_tar, tstart, cfg_scale_tar, duration, save_compute, export_format = "wav"):
-    tstart = torch.tensor(tstart, dtype=torch.int32)
-    w0, _ = inversion_reverse_process(ldm_stable, xT=wts, tstart=tstart, etas=1., prompts=[prompt_tar], neg_prompts=[""], cfg_scales=[cfg_scale_tar], zs=zs[:int(tstart)], duration=duration, extra_info=extra_info, save_compute=save_compute)
-
-    with inference_mode():
-        x0_dec = ldm_stable.vae_decode(w0.to(torch.float16 if config.is_half else torch.float32))
-
-    if x0_dec.dim() < 4: x0_dec = x0_dec[None, :, :, :]
-
-    with torch.no_grad():
-        audio = ldm_stable.decode_to_mel(x0_dec.to(torch.float16 if config.is_half else torch.float32))
-
-    audio = audio.float().squeeze().cpu().numpy()
-    orig_sr = 16000
-
-    if sr != 16000 and sr > 0:
-        audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=sr, res_type="soxr_vhq")
-        orig_sr = sr
-
-    audio = low_pass_filter(audio, 7500, orig_sr)
-
-    sf.write(output_audio, np.tile(audio, (2, 1)).T, orig_sr, format=export_format)
-    return output_audio
-
-def edit(input_audio, output_audio, model_id, source_prompt = "", target_prompt = "", steps = 200, cfg_scale_src = 3.5, cfg_scale_tar = 12, t_start = 45, save_compute = True, sr = 44100, device = "cpu", export_format = "wav"):
-    ldm_stable = load_model(model_id, device=device)
-    ldm_stable.model.scheduler.set_timesteps(steps, device=device)
-    x0, duration = load_audio(input_audio, ldm_stable.get_melspectrogram(), device=device)
-    zs_tensor, wts_tensor, extra_info_list = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt, num_diffusion_steps=steps, cfg_scale_src=cfg_scale_src, duration=duration, save_compute=save_compute)
-
-    return sample(output_audio, sr, ldm_stable, zs_tensor, wts_tensor, extra_info_list, prompt_tar=target_prompt, tstart=int(t_start / 100 * steps), cfg_scale_tar=cfg_scale_tar, duration=duration, save_compute=save_compute, export_format=export_format)
-
-def inversion_forward_process(model, x0, etas = None, prompts = [""], cfg_scales = [3.5], num_inference_steps = 50, numerical_fix = False, duration = None, first_order = False, save_compute = True):
-    if len(prompts) > 1 or prompts[0] != "":
-        text_embeddings_hidden_states, text_embeddings_class_labels, text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
-        uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text([""], negative=True, save_compute=save_compute, cond_length=text_embeddings_class_labels.shape[1] if text_embeddings_class_labels is not None else None)
-    else: uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text([""], negative=True, save_compute=False)
-
-    timesteps = model.model.scheduler.timesteps.to(model.device)
-    variance_noise_shape = model.get_noise_shape(x0, num_inference_steps)
-
-    if type(etas) in [int, float]: etas = [etas]*model.model.scheduler.num_inference_steps
-
-    xts = model.sample_xts_from_x0(x0, num_inference_steps=num_inference_steps)
-    zs = torch.zeros(size=variance_noise_shape, device=model.device)
-    extra_info = [None] * len(zs)
-
-    if timesteps[0].dtype == torch.int64: t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
-    elif timesteps[0].dtype == torch.float32: t_to_idx = {float(v): k for k, v in enumerate(timesteps)}
-
-    xt = x0
-    model.setup_extra_inputs(xt, init_timestep=timesteps[0], audio_end_in_s=duration, save_compute=save_compute and prompts[0] != "")
-
-    for t in tqdm.tqdm(timesteps, desc=translations["inverting"], ncols=100, unit="a"):
-        idx = num_inference_steps - t_to_idx[int(t) if timesteps[0].dtype == torch.int64 else float(t)] - 1
-        xt = xts[idx + 1][None]
-        xt_inp = model.model.scheduler.scale_model_input(xt, t).to(torch.float16 if config.is_half else torch.float32)
-
-        with torch.no_grad():
-            if save_compute and prompts[0] != "":
-                comb_out, _, _ = model.unet_forward(xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet') else xt_inp.expand(2, -1, -1), timestep=t, encoder_hidden_states=torch.cat([uncond_embeddings_hidden_states, text_embeddings_hidden_states], dim=0) if uncond_embeddings_hidden_states is not None else None, class_labels=torch.cat([uncond_embeddings_class_lables, text_embeddings_class_labels], dim=0) if uncond_embeddings_class_lables is not None else None, encoder_attention_mask=torch.cat([uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask], dim=0) if uncond_boolean_prompt_mask is not None else None)
-                out, cond_out = comb_out.sample.chunk(2, dim=0)
-            else:
-                out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=uncond_embeddings_hidden_states, class_labels=uncond_embeddings_class_lables, encoder_attention_mask=uncond_boolean_prompt_mask)[0].sample
-                if len(prompts) > 1 or prompts[0] != "": cond_out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=text_embeddings_hidden_states, class_labels=text_embeddings_class_labels, encoder_attention_mask=text_embeddings_boolean_prompt_mask)[0].sample
-
-        if len(prompts) > 1 or prompts[0] != "": noise_pred = out + (cfg_scales[0] * (cond_out - out)).sum(axis=0).unsqueeze(0)
-        else: noise_pred = out
-
-        xtm1 = xts[idx][None]
-        z, xtm1, extra = model.get_zs_from_xts(xt, xtm1, noise_pred, t, eta=etas[idx], numerical_fix=numerical_fix, first_order=first_order)
-        zs[idx] = z
-        xts[idx] = xtm1
-        extra_info[idx] = extra
-
-    if zs is not None: zs[0] = torch.zeros_like(zs[0])
-    return xt, zs, xts, extra_info
-
-def inversion_reverse_process(model, xT, tstart, etas = 0, prompts = [""], neg_prompts = [""], cfg_scales = None, zs = None, duration = None, first_order = False, extra_info = None, save_compute = True):
-    text_embeddings_hidden_states, text_embeddings_class_labels, text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
-    uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text(neg_prompts, negative=True, save_compute=save_compute, cond_length=text_embeddings_class_labels.shape[1] if text_embeddings_class_labels is not None else None)
-    xt = xT[tstart.max()].unsqueeze(0)
-
-    if etas is None: etas = 0
-    if type(etas) in [int, float]: etas = [etas]*model.model.scheduler.num_inference_steps
-
-    assert len(etas) == model.model.scheduler.num_inference_steps
-    timesteps = model.model.scheduler.timesteps.to(model.device)
-
-    if timesteps[0].dtype == torch.int64: t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
-    elif timesteps[0].dtype == torch.float32: t_to_idx = {float(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
-
-    model.setup_extra_inputs(xt, extra_info=extra_info, init_timestep=timesteps[-zs.shape[0]], audio_end_in_s=duration, save_compute=save_compute)
-
-    for t in tqdm.tqdm(timesteps[-zs.shape[0]:], desc=translations["editing"], ncols=100, unit="a"):
-        idx = model.model.scheduler.num_inference_steps - t_to_idx[int(t) if timesteps[0].dtype == torch.int64 else float(t)] - (model.model.scheduler.num_inference_steps - zs.shape[0] + 1)
-        xt_inp = model.model.scheduler.scale_model_input(xt, t).to(torch.float16 if config.is_half else torch.float32)
-
-        with torch.no_grad():
-            if save_compute:
-                comb_out, _, _ = model.unet_forward(xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet') else xt_inp.expand(2, -1, -1), timestep=t, encoder_hidden_states=torch.cat([uncond_embeddings_hidden_states, text_embeddings_hidden_states], dim=0) if uncond_embeddings_hidden_states is not None else None, class_labels=torch.cat([uncond_embeddings_class_lables, text_embeddings_class_labels], dim=0) if uncond_embeddings_class_lables is not None else None, encoder_attention_mask=torch.cat([uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask], dim=0) if uncond_boolean_prompt_mask is not None else None)
-                uncond_out, cond_out = comb_out.sample.chunk(2, dim=0)
-            else:
-                uncond_out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=uncond_embeddings_hidden_states, class_labels=uncond_embeddings_class_lables, encoder_attention_mask=uncond_boolean_prompt_mask)[0].sample
-                cond_out = model.unet_forward(xt_inp, timestep=t, encoder_hidden_states=text_embeddings_hidden_states, class_labels=text_embeddings_class_labels, encoder_attention_mask=text_embeddings_boolean_prompt_mask)[0].sample
-
-        z = zs[idx] if zs is not None else None
-        noise_pred = uncond_out + (cfg_scales[0] * (cond_out - uncond_out)).sum(axis=0).unsqueeze(0)
-        xt = model.reverse_step_with_custom_noise(noise_pred, t, xt, variance_noise=z.unsqueeze(0), eta=etas[idx], first_order=first_order)
-
-    return xt, zs
-
-if __name__ == "__main__": main()
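audioldm2.py edited an existing recording by inverting it under the source prompt and then re-running the tail of the diffusion under the target prompt: --t_start is a percentage, so the default 45 with 200 steps re-generates the final 90 reverse steps, and at each of those steps the unconditional and prompt-conditioned UNet outputs are combined with classifier-free guidance. A small sketch of just those two calculations with dummy tensors; shapes and variable names here are assumptions, only the formulas follow the script above:

import torch

steps, t_start, cfg_scale_tar = 200, 45, 12.0  # the script's defaults
tstart = int(t_start / 100 * steps)            # 90 reverse steps are re-generated

uncond_out = torch.randn(1, 8, 256, 16)        # hypothetical unconditional UNet output
cond_out = torch.randn(1, 8, 256, 16)          # hypothetical target-prompt UNet output

# Single-prompt classifier-free guidance, the same combination the editing loop applies per step.
noise_pred = uncond_out + cfg_scale_tar * (cond_out - uncond_out)
print(tstart, noise_pred.shape)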
main/inference/convert.py
DELETED
@@ -1,590 +0,0 @@
-import re
-import os
-import gc
-import sys
-import time
-import faiss
-import torch
-import librosa
-import logging
-import argparse
-import warnings
-import onnxruntime
-import logging.handlers
-
-import numpy as np
-import soundfile as sf
-import torch.nn.functional as F
-
-from tqdm import tqdm
-from scipy import signal
-from distutils.util import strtobool
-
-warnings.filterwarnings("ignore")
-sys.path.append(os.getcwd())
-
-from main.configs.config import Config
-from main.library.algorithm.synthesizers import Synthesizer
-from main.library.utils import check_predictors, check_embedders, load_audio, load_embedders_model, cut, restore
-
-bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=16000)
-config = Config()
-translations = config.translations
-logger = logging.getLogger(__name__)
-logger.propagate = False
-
-for l in ["torch", "faiss", "httpx", "fairseq", "httpcore", "faiss.loader", "numba.core", "urllib3", "transformers", "matplotlib"]:
-    logging.getLogger(l).setLevel(logging.ERROR)
-
-if logger.hasHandlers(): logger.handlers.clear()
-else:
-    console_handler = logging.StreamHandler()
-    console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
-    console_handler.setFormatter(console_formatter)
-    console_handler.setLevel(logging.INFO)
-    file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "convert.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
-    file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
-    file_handler.setFormatter(file_formatter)
-    file_handler.setLevel(logging.DEBUG)
-    logger.addHandler(console_handler)
-    logger.addHandler(file_handler)
-    logger.setLevel(logging.DEBUG)
-
-def parse_arguments():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--pitch", type=int, default=0)
-    parser.add_argument("--filter_radius", type=int, default=3)
-    parser.add_argument("--index_rate", type=float, default=0.5)
-    parser.add_argument("--volume_envelope", type=float, default=1)
-    parser.add_argument("--protect", type=float, default=0.33)
-    parser.add_argument("--hop_length", type=int, default=64)
-    parser.add_argument("--f0_method", type=str, default="rmvpe")
-    parser.add_argument("--embedder_model", type=str, default="contentvec_base")
-    parser.add_argument("--input_path", type=str, required=True)
-    parser.add_argument("--output_path", type=str, default="./audios/output.wav")
-    parser.add_argument("--export_format", type=str, default="wav")
-    parser.add_argument("--pth_path", type=str, required=True)
-    parser.add_argument("--index_path", type=str)
-    parser.add_argument("--f0_autotune", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--f0_autotune_strength", type=float, default=1)
-    parser.add_argument("--clean_audio", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--clean_strength", type=float, default=0.7)
-    parser.add_argument("--resample_sr", type=int, default=0)
-    parser.add_argument("--split_audio", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--checkpointing", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--f0_file", type=str, default="")
-    parser.add_argument("--f0_onnx", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--embedders_mode", type=str, default="fairseq")
-    parser.add_argument("--formant_shifting", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--formant_qfrency", type=float, default=0.8)
-    parser.add_argument("--formant_timbre", type=float, default=0.8)
-
-    return parser.parse_args()
-
-def main():
-    args = parse_arguments()
-    pitch, filter_radius, index_rate, volume_envelope, protect, hop_length, f0_method, input_path, output_path, pth_path, index_path, f0_autotune, f0_autotune_strength, clean_audio, clean_strength, export_format, embedder_model, resample_sr, split_audio, checkpointing, f0_file, f0_onnx, embedders_mode, formant_shifting, formant_qfrency, formant_timbre = args.pitch, args.filter_radius, args.index_rate, args.volume_envelope,args.protect, args.hop_length, args.f0_method, args.input_path, args.output_path, args.pth_path, args.index_path, args.f0_autotune, args.f0_autotune_strength, args.clean_audio, args.clean_strength, args.export_format, args.embedder_model, args.resample_sr, args.split_audio, args.checkpointing, args.f0_file, args.f0_onnx, args.embedders_mode, args.formant_shifting, args.formant_qfrency, args.formant_timbre
-
-    log_data = {translations['pitch']: pitch, translations['filter_radius']: filter_radius, translations['index_strength']: index_rate, translations['volume_envelope']: volume_envelope, translations['protect']: protect, "Hop length": hop_length, translations['f0_method']: f0_method, translations['audio_path']: input_path, translations['output_path']: output_path.replace('wav', export_format), translations['model_path']: pth_path, translations['indexpath']: index_path, translations['autotune']: f0_autotune, translations['clear_audio']: clean_audio, translations['export_format']: export_format, translations['hubert_model']: embedder_model, translations['split_audio']: split_audio, translations['memory_efficient_training']: checkpointing, translations["f0_onnx_mode"]: f0_onnx, translations["embed_mode"]: embedders_mode}
-
-    if clean_audio: log_data[translations['clean_strength']] = clean_strength
-    if resample_sr != 0: log_data[translations['sample_rate']] = resample_sr
-
-    if f0_autotune: log_data[translations['autotune_rate_info']] = f0_autotune_strength
-    if os.path.isfile(f0_file): log_data[translations['f0_file']] = f0_file
-
-    if formant_shifting:
-        log_data[translations['formant_qfrency']] = formant_qfrency
-        log_data[translations['formant_timbre']] = formant_timbre
-
-    for key, value in log_data.items():
-        logger.debug(f"{key}: {value}")
-
-    run_convert_script(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, input_path=input_path, output_path=output_path, pth_path=pth_path, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, embedder_model=embedder_model, resample_sr=resample_sr, split_audio=split_audio, checkpointing=checkpointing, f0_file=f0_file, f0_onnx=f0_onnx, embedders_mode=embedders_mode, formant_shifting=formant_shifting, formant_qfrency=formant_qfrency, formant_timbre=formant_timbre)
-
-def run_convert_script(pitch=0, filter_radius=3, index_rate=0.5, volume_envelope=1, protect=0.5, hop_length=64, f0_method="rmvpe", input_path=None, output_path="./output.wav", pth_path=None, index_path=None, f0_autotune=False, f0_autotune_strength=1, clean_audio=False, clean_strength=0.7, export_format="wav", embedder_model="contentvec_base", resample_sr=0, split_audio=False, checkpointing=False, f0_file=None, f0_onnx=False, embedders_mode="fairseq", formant_shifting=False, formant_qfrency=0.8, formant_timbre=0.8):
-    check_predictors(f0_method, f0_onnx); check_embedders(embedder_model, embedders_mode)
-
-    if not pth_path or not os.path.exists(pth_path) or os.path.isdir(pth_path) or not pth_path.endswith((".pth", ".onnx")):
-        logger.warning(translations["provide_file"].format(filename=translations["model"]))
-        sys.exit(1)
-
-    cvt = VoiceConverter(pth_path, 0)
-    start_time = time.time()
-
-    pid_path = os.path.join("assets", "convert_pid.txt")
-    with open(pid_path, "w") as pid_file:
-        pid_file.write(str(os.getpid()))
-
-    if os.path.isdir(input_path):
-        logger.info(translations["convert_batch"])
-        audio_files = [f for f in os.listdir(input_path) if f.lower().endswith(("wav", "mp3", "flac", "ogg", "opus", "m4a", "mp4", "aac", "alac", "wma", "aiff", "webm", "ac3"))]
-
-        if not audio_files:
-            logger.warning(translations["not_found_audio"])
-            sys.exit(1)
-
-        logger.info(translations["found_audio"].format(audio_files=len(audio_files)))
-
-        for audio in audio_files:
-            audio_path = os.path.join(input_path, audio)
-            output_audio = os.path.join(input_path, os.path.splitext(audio)[0] + f"_output.{export_format}")
-
-            logger.info(f"{translations['convert_audio']} '{audio_path}'...")
-            if os.path.exists(output_audio): os.remove(output_audio)
-            cvt.convert_audio(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, audio_input_path=audio_path, audio_output_path=output_audio, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, embedder_model=embedder_model, resample_sr=resample_sr, checkpointing=checkpointing, f0_file=f0_file, f0_onnx=f0_onnx, embedders_mode=embedders_mode, formant_shifting=formant_shifting, formant_qfrency=formant_qfrency, formant_timbre=formant_timbre, split_audio=split_audio)

-        logger.info(translations["convert_batch_success"].format(elapsed_time=f"{(time.time() - start_time):.2f}", output_path=output_path.replace('wav', export_format)))
-    else:
-        if not os.path.exists(input_path):
-            logger.warning(translations["not_found_audio"])
-            sys.exit(1)
-
-        logger.info(f"{translations['convert_audio']} '{input_path}'...")
-        if os.path.exists(output_path): os.remove(output_path)
-        cvt.convert_audio(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, audio_input_path=input_path, audio_output_path=output_path, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, embedder_model=embedder_model, resample_sr=resample_sr, checkpointing=checkpointing, f0_file=f0_file, f0_onnx=f0_onnx, embedders_mode=embedders_mode, formant_shifting=formant_shifting, formant_qfrency=formant_qfrency, formant_timbre=formant_timbre, split_audio=split_audio)

-    if os.path.exists(pid_path): os.remove(pid_path)
-    logger.info(translations["convert_audio_success"].format(input_path=input_path, elapsed_time=f"{(time.time() - start_time):.2f}", output_path=output_path.replace('wav', export_format)))
-
-def change_rms(source_audio, source_rate, target_audio, target_rate, rate):
-    rms2 = F.interpolate(torch.from_numpy(librosa.feature.rms(y=target_audio, frame_length=target_rate // 2 * 2, hop_length=target_rate // 2)).float().unsqueeze(0), size=target_audio.shape[0], mode="linear").squeeze()
-    return (target_audio * (torch.pow(F.interpolate(torch.from_numpy(librosa.feature.rms(y=source_audio, frame_length=source_rate // 2 * 2, hop_length=source_rate // 2)).float().unsqueeze(0), size=target_audio.shape[0], mode="linear").squeeze(), 1 - rate) * torch.pow(torch.maximum(rms2, torch.zeros_like(rms2) + 1e-6), rate - 1)).numpy())
-
-def clear_gpu_cache():
-    gc.collect()
-    if torch.cuda.is_available(): torch.cuda.empty_cache()
-    elif torch.backends.mps.is_available(): torch.mps.empty_cache()
-
-def get_providers():
-    ort_providers = onnxruntime.get_available_providers()
-
-    if "CUDAExecutionProvider" in ort_providers: providers = ["CUDAExecutionProvider"]
-    elif "CoreMLExecutionProvider" in ort_providers: providers = ["CoreMLExecutionProvider"]
-    else: providers = ["CPUExecutionProvider"]
-
-    return providers
-
-class Autotune:
-    def __init__(self, ref_freqs):
-        self.ref_freqs = ref_freqs
-        self.note_dict = self.ref_freqs
-
-    def autotune_f0(self, f0, f0_autotune_strength):
-        autotuned_f0 = np.zeros_like(f0)
-
-        for i, freq in enumerate(f0):
-            autotuned_f0[i] = freq + (min(self.note_dict, key=lambda x: abs(x - freq)) - freq) * f0_autotune_strength
-
-        return autotuned_f0
-
-class VC:
-    def __init__(self, tgt_sr, config):
-        self.x_pad = config.x_pad
-        self.x_query = config.x_query
-        self.x_center = config.x_center
-        self.x_max = config.x_max
-        self.sample_rate = 16000
-        self.window = 160
-        self.t_pad = self.sample_rate * self.x_pad
-        self.t_pad_tgt = tgt_sr * self.x_pad
-        self.t_pad2 = self.t_pad * 2
-        self.t_query = self.sample_rate * self.x_query
-        self.t_center = self.sample_rate * self.x_center
-        self.t_max = self.sample_rate * self.x_max
-        self.time_step = self.window / self.sample_rate * 1000
-        self.f0_min = 50
-        self.f0_max = 1100
-        self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
-        self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)
-        self.device = config.device
-        self.is_half = config.is_half
-        self.ref_freqs = [49.00, 51.91, 55.00, 58.27, 61.74, 65.41, 69.30, 73.42, 77.78, 82.41, 87.31, 92.50, 98.00, 103.83, 110.00, 116.54, 123.47, 130.81, 138.59, 146.83, 155.56, 164.81, 174.61, 185.00, 196.00, 207.65, 220.00, 233.08, 246.94, 261.63, 277.18, 293.66, 311.13, 329.63, 349.23, 369.99, 392.00, 415.30, 440.00, 466.16, 493.88, 523.25, 554.37, 587.33, 622.25, 659.25, 698.46, 739.99, 783.99, 830.61, 880.00, 932.33, 987.77, 1046.50]
-        self.autotune = Autotune(self.ref_freqs)
-        self.note_dict = self.autotune.note_dict
-
-    def get_f0_pm(self, x, p_len):
-        import parselmouth
-
-        f0 = (parselmouth.Sound(x, self.sample_rate).to_pitch_ac(time_step=self.window / self.sample_rate * 1000 / 1000, voicing_threshold=0.6, pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array["frequency"])
-        pad_size = (p_len - len(f0) + 1) // 2
-
-        if pad_size > 0 or p_len - len(f0) - pad_size > 0: f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
-        return f0
-
-    def get_f0_mangio_crepe(self, x, p_len, hop_length, model="full", onnx=False):
-        from main.library.predictors.CREPE import predict
-
-        x = x.astype(np.float32)
-        x /= np.quantile(np.abs(x), 0.999)
-
-        audio = torch.unsqueeze(torch.from_numpy(x).to(self.device, copy=True), dim=0)
-        if audio.ndim == 2 and audio.shape[0] > 1: audio = torch.mean(audio, dim=0, keepdim=True).detach()
-
-        p_len = p_len or x.shape[0] // hop_length
-        source = np.array(predict(audio.detach(), self.sample_rate, hop_length, self.f0_min, self.f0_max, model, batch_size=hop_length * 2, device=self.device, pad=True, providers=get_providers(), onnx=onnx).squeeze(0).cpu().float().numpy())
-        source[source < 0.001] = np.nan
-
-        return np.nan_to_num(np.interp(np.arange(0, len(source) * p_len, len(source)) / p_len, np.arange(0, len(source)), source))
-
-    def get_f0_crepe(self, x, model="full", onnx=False):
-        from main.library.predictors.CREPE import predict, mean, median
-
-        f0, pd = predict(torch.tensor(np.copy(x))[None].float(), self.sample_rate, self.window, self.f0_min, self.f0_max, model, batch_size=512, device=self.device, return_periodicity=True, providers=get_providers(), onnx=onnx)
-        f0, pd = mean(f0, 3), median(pd, 3)
-        f0[pd < 0.1] = 0
-
-        return f0[0].cpu().numpy()
-
-    def get_f0_fcpe(self, x, p_len, hop_length, onnx=False, legacy=False):
-        from main.library.predictors.FCPE import FCPE
-
-        model_fcpe = FCPE(os.path.join("assets", "models", "predictors", ("fcpe_legacy" if legacy else "fcpe") + (".onnx" if onnx else ".pt")), hop_length=int(hop_length), f0_min=int(self.f0_min), f0_max=int(self.f0_max), dtype=torch.float32, device=self.device, sample_rate=self.sample_rate, threshold=0.03 if legacy else 0.006, providers=get_providers(), onnx=onnx, legacy=legacy)
-        f0 = model_fcpe.compute_f0(x, p_len=p_len)
-
-        del model_fcpe
-        return f0
-
-    def get_f0_rmvpe(self, x, legacy=False, onnx=False):
-        from main.library.predictors.RMVPE import RMVPE
-
-        rmvpe_model = RMVPE(os.path.join("assets", "models", "predictors", "rmvpe" + (".onnx" if onnx else ".pt")), is_half=self.is_half, device=self.device, onnx=onnx, providers=get_providers())
-        f0 = rmvpe_model.infer_from_audio_with_pitch(x, thred=0.03, f0_min=self.f0_min, f0_max=self.f0_max) if legacy else rmvpe_model.infer_from_audio(x, thred=0.03)
-
-        del rmvpe_model
-        return f0
-
-    def get_f0_pyworld(self, x, filter_radius, model="harvest"):
-        from main.library.predictors.WORLD_WRAPPER import PYWORLD
-
-        pw = PYWORLD()
-        x = x.astype(np.double)
-
-        if model == "harvest": f0, t = pw.harvest(x, fs=self.sample_rate, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=10)
-        elif model == "dio": f0, t = pw.dio(x, fs=self.sample_rate, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=10)
-        else: raise ValueError(translations["method_not_valid"])
-
-        f0 = pw.stonemask(x, self.sample_rate, t, f0)
-
-        if filter_radius > 2 or model == "dio": f0 = signal.medfilt(f0, filter_radius)
-        return f0
-
-    def get_f0_swipe(self, x):
-        from main.library.predictors.SWIPE import swipe
-
-        f0, _ = swipe(x.astype(np.float32), self.sample_rate, f0_floor=self.f0_min, f0_ceil=self.f0_max, frame_period=10)
-        return f0
-
-    def get_f0_yin(self, x, hop_length, p_len, mode="yin"):
-        source = np.array(librosa.yin(x.astype(np.float32), sr=self.sample_rate, fmin=self.f0_min, fmax=self.f0_max, hop_length=hop_length) if mode == "yin" else librosa.pyin(x.astype(np.float32), fmin=self.f0_min, fmax=self.f0_max, sr=self.sample_rate, hop_length=hop_length)[0])
-        source[source < 0.001] = np.nan
-        return np.nan_to_num(np.interp(np.arange(0, len(source) * p_len, len(source)) / p_len, np.arange(0, len(source)), source))
-
-    def get_f0_hybrid(self, methods_str, x, p_len, hop_length, filter_radius, onnx_mode):
-        methods_str = re.search("hybrid\[(.+)\]", methods_str)
-        if methods_str: methods = [method.strip() for method in methods_str.group(1).split("+")]
-
-        f0_computation_stack, resampled_stack = [], []
-        logger.debug(translations["hybrid_methods"].format(methods=methods))
-
-        x = x.astype(np.float32)
-        x /= np.quantile(np.abs(x), 0.999)
-
-        for method in methods:
-            f0 = None
-            f0_methods = {"pm": lambda: self.get_f0_pm(x, p_len), "dio": lambda: self.get_f0_pyworld(x, filter_radius, "dio"), "mangio-crepe-tiny": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "tiny", onnx=onnx_mode), "mangio-crepe-small": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "small", onnx=onnx_mode), "mangio-crepe-medium": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "medium", onnx=onnx_mode), "mangio-crepe-large": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "large", onnx=onnx_mode), "mangio-crepe-full": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "full", onnx=onnx_mode), "crepe-tiny": lambda: self.get_f0_crepe(x, "tiny", onnx=onnx_mode), "crepe-small": lambda: self.get_f0_crepe(x, "small", onnx=onnx_mode), "crepe-medium": lambda: self.get_f0_crepe(x, "medium", onnx=onnx_mode), "crepe-large": lambda: self.get_f0_crepe(x, "large", onnx=onnx_mode), "crepe-full": lambda: self.get_f0_crepe(x, "full", onnx=onnx_mode), "fcpe": lambda: self.get_f0_fcpe(x, p_len, int(hop_length), onnx=onnx_mode), "fcpe-legacy": lambda: self.get_f0_fcpe(x, p_len, int(hop_length), legacy=True, onnx=onnx_mode), "rmvpe": lambda: self.get_f0_rmvpe(x, onnx=onnx_mode), "rmvpe-legacy": lambda: self.get_f0_rmvpe(x, legacy=True, onnx=onnx_mode), "harvest": lambda: self.get_f0_pyworld(x, filter_radius, "harvest"), "yin": lambda: self.get_f0_yin(x, int(hop_length), p_len, mode="yin"), "pyin": lambda: self.get_f0_yin(x, int(hop_length), p_len, mode="pyin"), "swipe": lambda: self.get_f0_swipe(x)}
-            f0 = f0_methods.get(method, lambda: ValueError(translations["method_not_valid"]))()
-            f0_computation_stack.append(f0)
-
-        for f0 in f0_computation_stack:
-            resampled_stack.append(np.interp(np.linspace(0, len(f0), p_len), np.arange(len(f0)), f0))
-
-        return resampled_stack[0] if len(resampled_stack) == 1 else np.nanmedian(np.vstack(resampled_stack), axis=0)
-
-    def get_f0(self, x, p_len, pitch, f0_method, filter_radius, hop_length, f0_autotune, f0_autotune_strength, inp_f0=None, onnx_mode=False):
-        f0_methods = {"pm": lambda: self.get_f0_pm(x, p_len), "dio": lambda: self.get_f0_pyworld(x, filter_radius, "dio"), "mangio-crepe-tiny": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "tiny", onnx=onnx_mode), "mangio-crepe-small": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "small", onnx=onnx_mode), "mangio-crepe-medium": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "medium", onnx=onnx_mode), "mangio-crepe-large": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "large", onnx=onnx_mode), "mangio-crepe-full": lambda: self.get_f0_mangio_crepe(x, p_len, int(hop_length), "full", onnx=onnx_mode), "crepe-tiny": lambda: self.get_f0_crepe(x, "tiny", onnx=onnx_mode), "crepe-small": lambda: self.get_f0_crepe(x, "small", onnx=onnx_mode), "crepe-medium": lambda: self.get_f0_crepe(x, "medium", onnx=onnx_mode), "crepe-large": lambda: self.get_f0_crepe(x, "large", onnx=onnx_mode), "crepe-full": lambda: self.get_f0_crepe(x, "full", onnx=onnx_mode), "fcpe": lambda: self.get_f0_fcpe(x, p_len, int(hop_length), onnx=onnx_mode), "fcpe-legacy": lambda: self.get_f0_fcpe(x, p_len, int(hop_length), legacy=True, onnx=onnx_mode), "rmvpe": lambda: self.get_f0_rmvpe(x, onnx=onnx_mode), "rmvpe-legacy": lambda: self.get_f0_rmvpe(x, legacy=True, onnx=onnx_mode), "harvest": lambda: self.get_f0_pyworld(x, filter_radius, "harvest"), "yin": lambda: self.get_f0_yin(x, int(hop_length), p_len, mode="yin"), "pyin": lambda: self.get_f0_yin(x, int(hop_length), p_len, mode="pyin"), "swipe": lambda: self.get_f0_swipe(x)}
-        f0 = self.get_f0_hybrid(f0_method, x, p_len, hop_length, filter_radius, onnx_mode) if "hybrid" in f0_method else f0_methods.get(f0_method, lambda: ValueError(translations["method_not_valid"]))()
-
-        if f0_autotune: f0 = Autotune.autotune_f0(self, f0, f0_autotune_strength)
-        if isinstance(f0, tuple): f0 = f0[0]
-
-        f0 *= pow(2, pitch / 12)
-        tf0 = self.sample_rate // self.window
-
-        if inp_f0 is not None:
-            replace_f0 = np.interp(list(range(np.round((inp_f0[:, 0].max() - inp_f0[:, 0].min()) * tf0 + 1).astype(np.int16))), inp_f0[:, 0] * 100, inp_f0[:, 1])
-            f0[self.x_pad * tf0 : self.x_pad * tf0 + len(replace_f0)] = replace_f0[:f0[self.x_pad * tf0 : self.x_pad * tf0 + len(replace_f0)].shape[0]]
-
-        f0_mel = 1127 * np.log(1 + f0 / 700)
-        f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - self.f0_mel_min) * 254 / (self.f0_mel_max - self.f0_mel_min) + 1
-        f0_mel[f0_mel <= 1] = 1
-        f0_mel[f0_mel > 255] = 255
-
-        return np.rint(f0_mel).astype(np.int32), f0.copy()
-
-    def extract_features(self, model, feats, version):
-        return torch.as_tensor(model.run([model.get_outputs()[0].name, model.get_outputs()[1].name], {"feats": feats.detach().cpu().numpy()})[0 if version == "v1" else 1], dtype=torch.float32, device=feats.device)
-
-    def voice_conversion(self, model, net_g, sid, audio0, pitch, pitchf, index, big_npy, index_rate, version, protect):
-        pitch_guidance = pitch != None and pitchf != None
-        feats = (torch.from_numpy(audio0).half() if self.is_half else torch.from_numpy(audio0).float())
-
-        if feats.dim() == 2: feats = feats.mean(-1)
-        assert feats.dim() == 1, feats.dim()
-        feats = feats.view(1, -1)
-
-        with torch.no_grad():
-            if self.embed_suffix == ".pt":
-                padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False)
-                logits = model.extract_features(**{"source": feats.to(self.device), "padding_mask": padding_mask, "output_layer": 9 if version == "v1" else 12})
-                feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
-            elif self.embed_suffix == ".onnx": feats = self.extract_features(model, feats.to(self.device), version).to(self.device)
-            elif self.embed_suffix == ".safetensors":
-                logits = model(feats.to(self.device))["last_hidden_state"]
-                feats = (model.final_proj(logits[0]).unsqueeze(0) if version == "v1" else logits)
-            else: raise ValueError(translations["option_not_valid"])
-
-        if protect < 0.5 and pitch_guidance: feats0 = feats.clone()
-
-        if (not isinstance(index, type(None)) and not isinstance(big_npy, type(None)) and index_rate != 0):
-            npy = feats[0].cpu().numpy()
-            if self.is_half: npy = npy.astype(np.float32)
-
-            score, ix = index.search(npy, k=8)
-            weight = np.square(1 / score)
-
-            npy = np.sum(big_npy[ix] * np.expand_dims(weight / weight.sum(axis=1, keepdims=True), axis=2), axis=1)
-            if self.is_half: npy = npy.astype(np.float16)
-
-            feats = (torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate + (1 - index_rate) * feats)
-
-        feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
-        if protect < 0.5 and pitch_guidance: feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
-
-        p_len = audio0.shape[0] // self.window
-
-        if feats.shape[1] < p_len:
-            p_len = feats.shape[1]
-            if pitch_guidance:
-                pitch = pitch[:, :p_len]
-                pitchf = pitchf[:, :p_len]
-
-        if protect < 0.5 and pitch_guidance:
-            pitchff = pitchf.clone()
-            pitchff[pitchf > 0] = 1
-            pitchff[pitchf < 1] = protect
-            pitchff = pitchff.unsqueeze(-1)
-
-            feats = (feats * pitchff + feats0 * (1 - pitchff)).to(feats0.dtype)
-
-        p_len = torch.tensor([p_len], device=self.device).long()
-        audio1 = ((net_g.infer(feats.half() if self.is_half else feats.float(), p_len, pitch if pitch_guidance else None, (pitchf.half() if self.is_half else pitchf.float()) if pitch_guidance else None, sid)[0][0, 0]).data.cpu().float().numpy()) if self.suffix == ".pth" else (net_g.run([net_g.get_outputs()[0].name], ({net_g.get_inputs()[0].name: feats.cpu().numpy().astype(np.float32), net_g.get_inputs()[1].name: p_len.cpu().numpy(), net_g.get_inputs()[2].name: np.array([sid.cpu().item()], dtype=np.int64), net_g.get_inputs()[3].name: np.random.randn(1, 192, p_len).astype(np.float32), net_g.get_inputs()[4].name: pitch.cpu().numpy().astype(np.int64), net_g.get_inputs()[5].name: pitchf.cpu().numpy().astype(np.float32)} if pitch_guidance else {net_g.get_inputs()[0].name: feats.cpu().numpy().astype(np.float32), net_g.get_inputs()[1].name: p_len.cpu().numpy(), net_g.get_inputs()[2].name: np.array([sid.cpu().item()], dtype=np.int64), net_g.get_inputs()[3].name: np.random.randn(1, 192, p_len).astype(np.float32)}))[0][0, 0])

-        if self.embed_suffix == ".pt": del padding_mask
-        del feats, p_len, net_g
-        clear_gpu_cache()
-        return audio1
-
-    def pipeline(self, model, net_g, sid, audio, pitch, f0_method, file_index, index_rate, pitch_guidance, filter_radius, volume_envelope, version, protect, hop_length, f0_autotune, f0_autotune_strength, suffix, embed_suffix, f0_file=None, f0_onnx=False, pbar=None):
-        self.suffix = suffix
-        self.embed_suffix = embed_suffix
-
-        if file_index != "" and os.path.exists(file_index) and index_rate != 0:
-            try:
-                index = faiss.read_index(file_index)
-                big_npy = index.reconstruct_n(0, index.ntotal)
-            except Exception as e:
-                logger.error(translations["read_faiss_index_error"].format(e=e))
-                index = big_npy = None
-        else: index = big_npy = None
-
-        pbar.update(1)
-        opt_ts, audio_opt = [], []
-        audio = signal.filtfilt(bh, ah, audio)
-        audio_pad = np.pad(audio, (self.window // 2, self.window // 2), mode="reflect")
-
-        if audio_pad.shape[0] > self.t_max:
-            audio_sum = np.zeros_like(audio)
-            for i in range(self.window):
-                audio_sum += audio_pad[i : i - self.window]
-
-            for t in range(self.t_center, audio.shape[0], self.t_center):
-                opt_ts.append(t - self.t_query + np.where(np.abs(audio_sum[t - self.t_query : t + self.t_query]) == np.abs(audio_sum[t - self.t_query : t + self.t_query]).min())[0][0])
-
-        s = 0
-        t, inp_f0 = None, None
-        audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect")
-        sid = torch.tensor(sid, device=self.device).unsqueeze(0).long()
-        p_len = audio_pad.shape[0] // self.window
-
-        if hasattr(f0_file, "name"):
-            try:
-                with open(f0_file.name, "r") as f:
-                    raw_lines = f.read()
-                    if len(raw_lines) > 0:
-                        inp_f0 = []
-                        for line in raw_lines.strip("\n").split("\n"):
-                            inp_f0.append([float(i) for i in line.split(",")])
-
-                        inp_f0 = np.array(inp_f0, dtype=np.float32)
-            except:
-                logger.error(translations["error_readfile"])
-                inp_f0 = None
-
-        pbar.update(1)
-        if pitch_guidance:
-            pitch, pitchf = self.get_f0(audio_pad, p_len, pitch, f0_method, filter_radius, hop_length, f0_autotune, f0_autotune_strength, inp_f0, onnx_mode=f0_onnx)
-            pitch, pitchf = pitch[:p_len], pitchf[:p_len]
-            if self.device == "mps": pitchf = pitchf.astype(np.float32)
-            pitch, pitchf = torch.tensor(pitch, device=self.device).unsqueeze(0).long(), torch.tensor(pitchf, device=self.device).unsqueeze(0).float()
-
-        pbar.update(1)
-        for t in opt_ts:
-            t = t // self.window * self.window
-            audio_opt.append(self.voice_conversion(model, net_g, sid, audio_pad[s : t + self.t_pad2 + self.window], pitch[:, s // self.window : (t + self.t_pad2) // self.window] if pitch_guidance else None, pitchf[:, s // self.window : (t + self.t_pad2) // self.window] if pitch_guidance else None, index, big_npy, index_rate, version, protect)[self.t_pad_tgt : -self.t_pad_tgt])
-            s = t
-
-        audio_opt.append(self.voice_conversion(model, net_g, sid, audio_pad[t:], (pitch[:, t // self.window :] if t is not None else pitch) if pitch_guidance else None, (pitchf[:, t // self.window :] if t is not None else pitchf) if pitch_guidance else None, index, big_npy, index_rate, version, protect)[self.t_pad_tgt : -self.t_pad_tgt])
|
448 |
-
audio_opt = np.concatenate(audio_opt)
|
449 |
-
if volume_envelope != 1: audio_opt = change_rms(audio, self.sample_rate, audio_opt, self.sample_rate, volume_envelope)
|
450 |
-
audio_max = np.abs(audio_opt).max() / 0.99
|
451 |
-
if audio_max > 1: audio_opt /= audio_max
|
452 |
-
|
453 |
-
if pitch_guidance: del pitch, pitchf
|
454 |
-
del sid
|
455 |
-
clear_gpu_cache()
|
456 |
-
pbar.update(1)
|
457 |
-
|
458 |
-
return audio_opt
|
459 |
-
|
460 |
-
class VoiceConverter:
|
461 |
-
def __init__(self, model_path, sid = 0):
|
462 |
-
self.config = config
|
463 |
-
self.device = config.device
|
464 |
-
self.hubert_model = None
|
465 |
-
self.tgt_sr = None
|
466 |
-
self.net_g = None
|
467 |
-
self.vc = None
|
468 |
-
self.cpt = None
|
469 |
-
self.version = None
|
470 |
-
self.n_spk = None
|
471 |
-
self.use_f0 = None
|
472 |
-
self.loaded_model = None
|
473 |
-
self.vocoder = "Default"
|
474 |
-
self.checkpointing = False
|
475 |
-
self.sample_rate = 16000
|
476 |
-
self.sid = sid
|
477 |
-
self.get_vc(model_path, sid)
|
478 |
-
|
479 |
-
def convert_audio(self, audio_input_path, audio_output_path, index_path, embedder_model, pitch, f0_method, index_rate, volume_envelope, protect, hop_length, f0_autotune, f0_autotune_strength, filter_radius, clean_audio, clean_strength, export_format, resample_sr = 0, checkpointing = False, f0_file = None, f0_onnx = False, embedders_mode = "fairseq", formant_shifting = False, formant_qfrency = 0.8, formant_timbre = 0.8, split_audio = False):
|
480 |
-
try:
|
481 |
-
with tqdm(total=10, desc=translations["convert_audio"], ncols=100, unit="a") as pbar:
|
482 |
-
audio = load_audio(logger, audio_input_path, self.sample_rate, formant_shifting=formant_shifting, formant_qfrency=formant_qfrency, formant_timbre=formant_timbre)
|
483 |
-
self.checkpointing = checkpointing
|
484 |
-
audio_max = np.abs(audio).max() / 0.95
|
485 |
-
if audio_max > 1: audio /= audio_max
|
486 |
-
|
487 |
-
pbar.update(1)
|
488 |
-
if not self.hubert_model:
|
489 |
-
models, _, embed_suffix = load_embedders_model(embedder_model, embedders_mode, providers=get_providers())
|
490 |
-
self.hubert_model = (models.to(self.device).half() if self.config.is_half else models.to(self.device).float()).eval() if embed_suffix in [".pt", ".safetensors"] else models
|
491 |
-
self.embed_suffix = embed_suffix
|
492 |
-
|
493 |
-
pbar.update(1)
|
494 |
-
if self.tgt_sr != resample_sr >= self.sample_rate: self.tgt_sr = resample_sr
|
495 |
-
target_sr = min([8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100, 48000, 96000], key=lambda x: abs(x - self.tgt_sr))
|
496 |
-
|
497 |
-
if split_audio:
|
498 |
-
chunks = cut(audio, self.sample_rate, db_thresh=-60, min_interval=500)
|
499 |
-
pbar.total = len(chunks) * 4 + 6
|
500 |
-
logger.info(f"{translations['split_total']}: {len(chunks)}")
|
501 |
-
else: chunks = [(audio, 0, 0)]
|
502 |
-
|
503 |
-
converted_chunks = []
|
504 |
-
pbar.update(1)
|
505 |
-
|
506 |
-
for waveform, start, end in chunks:
|
507 |
-
converted_chunks.append((start, end, self.vc.pipeline(model=self.hubert_model, net_g=self.net_g, sid=self.sid, audio=waveform, pitch=pitch, f0_method=f0_method, file_index=(index_path.strip().strip('"').strip("\n").strip('"').strip().replace("trained", "added")), index_rate=index_rate, pitch_guidance=self.use_f0, filter_radius=filter_radius, volume_envelope=volume_envelope, version=self.version, protect=protect, hop_length=hop_length, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, suffix=self.suffix, embed_suffix=self.embed_suffix, f0_file=f0_file, f0_onnx=f0_onnx, pbar=pbar)))
|
508 |
-
|
509 |
-
pbar.update(1)
|
510 |
-
audio_output = restore(converted_chunks, total_len=len(audio), dtype=converted_chunks[0][2].dtype) if split_audio else converted_chunks[0][2]
|
511 |
-
if target_sr >= self.sample_rate and self.tgt_sr != target_sr: audio_output = librosa.resample(audio_output, orig_sr=self.tgt_sr, target_sr=target_sr, res_type="soxr_vhq")
|
512 |
-
|
513 |
-
pbar.update(1)
|
514 |
-
if clean_audio:
|
515 |
-
from main.tools.noisereduce import reduce_noise
|
516 |
-
audio_output = reduce_noise(y=audio_output, sr=target_sr, prop_decrease=clean_strength, device=self.device)
|
517 |
-
|
518 |
-
sf.write(audio_output_path, audio_output, target_sr, format=export_format)
|
519 |
-
pbar.update(1)
|
520 |
-
except Exception as e:
|
521 |
-
logger.error(translations["error_convert"].format(e=e))
|
522 |
-
import traceback
|
523 |
-
logger.debug(traceback.format_exc())
|
524 |
-
|
525 |
-
def get_vc(self, weight_root, sid):
|
526 |
-
if sid == "" or sid == []:
|
527 |
-
self.cleanup()
|
528 |
-
clear_gpu_cache()
|
529 |
-
|
530 |
-
if not self.loaded_model or self.loaded_model != weight_root:
|
531 |
-
self.loaded_model = weight_root
|
532 |
-
self.load_model()
|
533 |
-
if self.cpt is not None: self.setup()
|
534 |
-
|
535 |
-
def cleanup(self):
|
536 |
-
if self.hubert_model is not None:
|
537 |
-
del self.net_g, self.n_spk, self.vc, self.hubert_model, self.tgt_sr
|
538 |
-
self.hubert_model = self.net_g = self.n_spk = self.vc = self.tgt_sr = None
|
539 |
-
clear_gpu_cache()
|
540 |
-
|
541 |
-
del self.net_g, self.cpt
|
542 |
-
clear_gpu_cache()
|
543 |
-
self.cpt = None
|
544 |
-
|
545 |
-
def load_model(self):
|
546 |
-
if os.path.isfile(self.loaded_model):
|
547 |
-
if self.loaded_model.endswith(".pth"): self.cpt = torch.load(self.loaded_model, map_location="cpu")
|
548 |
-
else:
|
549 |
-
sess_options = onnxruntime.SessionOptions()
|
550 |
-
sess_options.log_severity_level = 3
|
551 |
-
self.cpt = onnxruntime.InferenceSession(self.loaded_model, sess_options=sess_options, providers=get_providers())
|
552 |
-
else: self.cpt = None
|
553 |
-
|
554 |
-
def setup(self):
|
555 |
-
if self.cpt is not None:
|
556 |
-
if self.loaded_model.endswith(".pth"):
|
557 |
-
self.tgt_sr = self.cpt["config"][-1]
|
558 |
-
self.cpt["config"][-3] = self.cpt["weight"]["emb_g.weight"].shape[0]
|
559 |
-
self.use_f0 = self.cpt.get("f0", 1)
|
560 |
-
self.version = self.cpt.get("version", "v1")
|
561 |
-
self.vocoder = self.cpt.get("vocoder", "Default")
|
562 |
-
if self.vocoder != "Default": self.config.is_half = False
|
563 |
-
|
564 |
-
self.net_g = Synthesizer(*self.cpt["config"], use_f0=self.use_f0, text_enc_hidden_dim=768 if self.version == "v2" else 256, vocoder=self.vocoder, checkpointing=self.checkpointing)
|
565 |
-
del self.net_g.enc_q
|
566 |
-
|
567 |
-
self.net_g.load_state_dict(self.cpt["weight"], strict=False)
|
568 |
-
self.net_g.eval().to(self.device)
|
569 |
-
self.net_g = (self.net_g.half() if self.config.is_half else self.net_g.float())
|
570 |
-
self.n_spk = self.cpt["config"][-3]
|
571 |
-
self.suffix = ".pth"
|
572 |
-
else:
|
573 |
-
import json
|
574 |
-
import onnx
|
575 |
-
|
576 |
-
metadata_dict = None
|
577 |
-
for prop in onnx.load(self.loaded_model).metadata_props:
|
578 |
-
if prop.key == "model_info":
|
579 |
-
metadata_dict = json.loads(prop.value)
|
580 |
-
break
|
581 |
-
|
582 |
-
self.net_g = self.cpt
|
583 |
-
self.tgt_sr = metadata_dict.get("sr", 32000)
|
584 |
-
self.use_f0 = metadata_dict.get("f0", 1)
|
585 |
-
self.version = metadata_dict.get("version", "v1")
|
586 |
-
self.suffix = ".onnx"
|
587 |
-
|
588 |
-
self.vc = VC(self.tgt_sr, self.config)
|
589 |
-
|
590 |
-
if __name__ == "__main__": main()
|
|
|
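The deleted `voice_conversion` method above retrieves the k nearest training features from the FAISS index and blends them into the content features before synthesis. A minimal standalone sketch of that blend, assuming a prebuilt index, its reconstructed vectors `big_npy`, and an `index_rate` in [0, 1] (the epsilon guard on the distances is an addition, not in the original file):

    import numpy as np

    def blend_with_index(feats, index, big_npy, index_rate, k=8):
        # feats: (frames, dim) float32 content features for one segment
        score, ix = index.search(feats, k)                 # squared distances and neighbour ids
        weight = np.square(1.0 / np.maximum(score, 1e-9))  # inverse-square distance weights
        weight /= weight.sum(axis=1, keepdims=True)
        retrieved = np.sum(big_npy[ix] * weight[..., None], axis=1)
        # index_rate = 1 keeps only retrieved features, 0 keeps the originals unchanged
        return retrieved * index_rate + feats * (1.0 - index_rate)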
main/inference/create_dataset.py
DELETED
@@ -1,230 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import sys
|
3 |
-
import time
|
4 |
-
import yt_dlp
|
5 |
-
import shutil
|
6 |
-
import librosa
|
7 |
-
import logging
|
8 |
-
import argparse
|
9 |
-
import warnings
|
10 |
-
import logging.handlers
|
11 |
-
|
12 |
-
from soundfile import read, write
|
13 |
-
from distutils.util import strtobool
|
14 |
-
|
15 |
-
sys.path.append(os.getcwd())
|
16 |
-
|
17 |
-
from main.configs.config import Config
|
18 |
-
from main.library.algorithm.separator import Separator
|
19 |
-
|
20 |
-
config = Config()
|
21 |
-
translations = config.translations
|
22 |
-
dataset_temp = os.path.join("dataset_temp")
|
23 |
-
logger = logging.getLogger(__name__)
|
24 |
-
|
25 |
-
if logger.hasHandlers(): logger.handlers.clear()
|
26 |
-
else:
|
27 |
-
console_handler = logging.StreamHandler()
|
28 |
-
console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
29 |
-
console_handler.setFormatter(console_formatter)
|
30 |
-
console_handler.setLevel(logging.INFO)
|
31 |
-
file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "create_dataset.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
|
32 |
-
file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
33 |
-
file_handler.setFormatter(file_formatter)
|
34 |
-
file_handler.setLevel(logging.DEBUG)
|
35 |
-
logger.addHandler(console_handler)
|
36 |
-
logger.addHandler(file_handler)
|
37 |
-
logger.setLevel(logging.DEBUG)
|
38 |
-
|
39 |
-
def parse_arguments():
|
40 |
-
parser = argparse.ArgumentParser()
|
41 |
-
parser.add_argument("--input_audio", type=str, required=True)
|
42 |
-
parser.add_argument("--output_dataset", type=str, default="./dataset")
|
43 |
-
parser.add_argument("--sample_rate", type=int, default=44100)
|
44 |
-
parser.add_argument("--clean_dataset", type=lambda x: bool(strtobool(x)), default=False)
|
45 |
-
parser.add_argument("--clean_strength", type=float, default=0.7)
|
46 |
-
parser.add_argument("--separator_reverb", type=lambda x: bool(strtobool(x)), default=False)
|
47 |
-
parser.add_argument("--kim_vocal_version", type=int, default=2)
|
48 |
-
parser.add_argument("--overlap", type=float, default=0.25)
|
49 |
-
parser.add_argument("--segments_size", type=int, default=256)
|
50 |
-
parser.add_argument("--mdx_hop_length", type=int, default=1024)
|
51 |
-
parser.add_argument("--mdx_batch_size", type=int, default=1)
|
52 |
-
parser.add_argument("--denoise_mdx", type=lambda x: bool(strtobool(x)), default=False)
|
53 |
-
parser.add_argument("--skip", type=lambda x: bool(strtobool(x)), default=False)
|
54 |
-
parser.add_argument("--skip_start_audios", type=str, default="0")
|
55 |
-
parser.add_argument("--skip_end_audios", type=str, default="0")
|
56 |
-
|
57 |
-
return parser.parse_args()
|
58 |
-
|
59 |
-
def main():
|
60 |
-
pid_path = os.path.join("assets", "create_dataset_pid.txt")
|
61 |
-
with open(pid_path, "w") as pid_file:
|
62 |
-
pid_file.write(str(os.getpid()))
|
63 |
-
|
64 |
-
args = parse_arguments()
|
65 |
-
input_audio, output_dataset, sample_rate, clean_dataset, clean_strength, separator_reverb, kim_vocal_version, overlap, segments_size, hop_length, batch_size, denoise_mdx, skip, skip_start_audios, skip_end_audios = args.input_audio, args.output_dataset, args.sample_rate, args.clean_dataset, args.clean_strength, args.separator_reverb, args.kim_vocal_version, args.overlap, args.segments_size, args.mdx_hop_length, args.mdx_batch_size, args.denoise_mdx, args.skip, args.skip_start_audios, args.skip_end_audios
|
66 |
-
log_data = {translations['audio_path']: input_audio, translations['output_path']: output_dataset, translations['sr']: sample_rate, translations['clear_dataset']: clean_dataset, translations['dereveb_audio']: separator_reverb, translations['segments_size']: segments_size, translations['overlap']: overlap, "Hop length": hop_length, translations['batch_size']: batch_size, translations['denoise_mdx']: denoise_mdx, translations['skip']: skip}
|
67 |
-
|
68 |
-
if clean_dataset: log_data[translations['clean_strength']] = clean_strength
|
69 |
-
if skip:
|
70 |
-
log_data[translations['skip_start']] = skip_start_audios
|
71 |
-
log_data[translations['skip_end']] = skip_end_audios
|
72 |
-
|
73 |
-
for key, value in log_data.items():
|
74 |
-
logger.debug(f"{key}: {value}")
|
75 |
-
|
76 |
-
if kim_vocal_version not in [1, 2]: raise ValueError(translations["version_not_valid"])
|
77 |
-
start_time = time.time()
|
78 |
-
|
79 |
-
try:
|
80 |
-
paths = []
|
81 |
-
|
82 |
-
if not os.path.exists(dataset_temp): os.makedirs(dataset_temp, exist_ok=True)
|
83 |
-
urls = input_audio.replace(", ", ",").split(",")
|
84 |
-
|
85 |
-
for url in urls:
|
86 |
-
path = downloader(url, urls.index(url))
|
87 |
-
paths.append(path)
|
88 |
-
|
89 |
-
if skip:
|
90 |
-
skip_start_audios, skip_end_audios = skip_start_audios.replace(", ", ",").split(","), skip_end_audios.replace(", ", ",").split(",")
|
91 |
-
|
92 |
-
if len(skip_start_audios) < len(paths) or len(skip_end_audios) < len(paths):
|
93 |
-
logger.warning(translations["skip<audio"])
|
94 |
-
sys.exit(1)
|
95 |
-
elif len(skip_start_audios) > len(paths) or len(skip_end_audios) > len(paths):
|
96 |
-
logger.warning(translations["skip>audio"])
|
97 |
-
sys.exit(1)
|
98 |
-
else:
|
99 |
-
for audio, skip_start_audio, skip_end_audio in zip(paths, skip_start_audios, skip_end_audios):
|
100 |
-
skip_start(audio, skip_start_audio)
|
101 |
-
skip_end(audio, skip_end_audio)
|
102 |
-
|
103 |
-
separator_paths = []
|
104 |
-
|
105 |
-
for audio in paths:
|
106 |
-
vocals = separator_music_main(audio, dataset_temp, segments_size, overlap, denoise_mdx, kim_vocal_version, hop_length, batch_size, sample_rate)
|
107 |
-
if separator_reverb: vocals = separator_reverb_audio(vocals, dataset_temp, segments_size, overlap, denoise_mdx, hop_length, batch_size, sample_rate)
|
108 |
-
separator_paths.append(vocals)
|
109 |
-
|
110 |
-
paths = separator_paths
|
111 |
-
|
112 |
-
for audio_path in paths:
|
113 |
-
data, sample_rate = read(audio_path)
|
114 |
-
data = librosa.to_mono(data.T)
|
115 |
-
|
116 |
-
if clean_dataset:
|
117 |
-
from main.tools.noisereduce import reduce_noise
|
118 |
-
data = reduce_noise(y=data, prop_decrease=clean_strength, device=config.device)
|
119 |
-
|
120 |
-
write(audio_path, data, sample_rate)
|
121 |
-
except Exception as e:
|
122 |
-
logger.error(f"{translations['create_dataset_error']}: {e}")
|
123 |
-
import traceback
|
124 |
-
logger.error(traceback.format_exc())
|
125 |
-
finally:
|
126 |
-
for audio in paths:
|
127 |
-
shutil.move(audio, output_dataset)
|
128 |
-
|
129 |
-
if os.path.exists(dataset_temp): shutil.rmtree(dataset_temp, ignore_errors=True)
|
130 |
-
|
131 |
-
elapsed_time = time.time() - start_time
|
132 |
-
if os.path.exists(pid_path): os.remove(pid_path)
|
133 |
-
logger.info(translations["create_dataset_success"].format(elapsed_time=f"{elapsed_time:.2f}"))
|
134 |
-
|
135 |
-
def downloader(url, name):
|
136 |
-
with warnings.catch_warnings():
|
137 |
-
warnings.simplefilter("ignore")
|
138 |
-
|
139 |
-
ydl_opts = {"format": "bestaudio/best", "outtmpl": os.path.join(dataset_temp, f"{name}"), "postprocessors": [{"key": "FFmpegExtractAudio", "preferredcodec": "wav", "preferredquality": "192"}], "no_warnings": True, "noplaylist": True, "noplaylist": True, "verbose": False}
|
140 |
-
logger.info(f"{translations['starting_download']}: {url}...")
|
141 |
-
|
142 |
-
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
143 |
-
ydl.extract_info(url)
|
144 |
-
logger.info(f"{translations['download_success']}: {url}")
|
145 |
-
|
146 |
-
return os.path.join(dataset_temp, f"{name}" + ".wav")
|
147 |
-
|
148 |
-
def skip_start(input_file, seconds):
|
149 |
-
data, sr = read(input_file)
|
150 |
-
total_duration = len(data) / sr
|
151 |
-
|
152 |
-
if seconds <= 0: logger.warning(translations["=<0"])
|
153 |
-
elif seconds >= total_duration: logger.warning(translations["skip_warning"].format(seconds=seconds, total_duration=f"{total_duration:.2f}"))
|
154 |
-
else:
|
155 |
-
logger.info(f"{translations['skip_start']}: {input_file}...")
|
156 |
-
write(input_file, data[int(seconds * sr):], sr)
|
157 |
-
|
158 |
-
logger.info(translations["skip_start_audio"].format(input_file=input_file))
|
159 |
-
|
160 |
-
def skip_end(input_file, seconds):
|
161 |
-
data, sr = read(input_file)
|
162 |
-
total_duration = len(data) / sr
|
163 |
-
|
164 |
-
if seconds <= 0: logger.warning(translations["=<0"])
|
165 |
-
elif seconds > total_duration: logger.warning(translations["skip_warning"].format(seconds=seconds, total_duration=f"{total_duration:.2f}"))
|
166 |
-
else:
|
167 |
-
logger.info(f"{translations['skip_end']}: {input_file}...")
|
168 |
-
write(input_file, data[:-int(seconds * sr)], sr)
|
169 |
-
|
170 |
-
logger.info(translations["skip_end_audio"].format(input_file=input_file))
|
171 |
-
|
172 |
-
def separator_music_main(input, output, segments_size, overlap, denoise, version, hop_length, batch_size, sample_rate):
|
173 |
-
if not os.path.exists(input):
|
174 |
-
logger.warning(translations["input_not_valid"])
|
175 |
-
return None
|
176 |
-
|
177 |
-
if not os.path.exists(output):
|
178 |
-
logger.warning(translations["output_not_valid"])
|
179 |
-
return None
|
180 |
-
|
181 |
-
model = f"Kim_Vocal_{version}.onnx"
|
182 |
-
output_separator = separator_main(audio_file=input, model_filename=model, output_format="wav", output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise, sample_rate=sample_rate)
|
183 |
-
|
184 |
-
for f in output_separator:
|
185 |
-
path = os.path.join(output, f)
|
186 |
-
if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
|
187 |
-
|
188 |
-
if '_(Instrumental)_' in f: os.rename(path, os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav")
|
189 |
-
elif '_(Vocals)_' in f:
|
190 |
-
rename_file = os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav"
|
191 |
-
os.rename(path, rename_file)
|
192 |
-
|
193 |
-
return rename_file
|
194 |
-
|
195 |
-
def separator_reverb_audio(input, output, segments_size, overlap, denoise, hop_length, batch_size, sample_rate):
|
196 |
-
if not os.path.exists(input):
|
197 |
-
logger.warning(translations["input_not_valid"])
|
198 |
-
return None
|
199 |
-
|
200 |
-
if not os.path.exists(output):
|
201 |
-
logger.warning(translations["output_not_valid"])
|
202 |
-
return None
|
203 |
-
|
204 |
-
logger.info(f"{translations['dereverb']}: {input}...")
|
205 |
-
output_dereverb = separator_main(audio_file=input, model_filename="Reverb_HQ_By_FoxJoy.onnx", output_format="wav", output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=hop_length, mdx_hop_length=batch_size, mdx_enable_denoise=denoise, sample_rate=sample_rate)
|
206 |
-
|
207 |
-
for f in output_dereverb:
|
208 |
-
path = os.path.join(output, f)
|
209 |
-
if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
|
210 |
-
|
211 |
-
if '_(Reverb)_' in f: os.rename(path, os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav")
|
212 |
-
elif '_(No Reverb)_' in f:
|
213 |
-
rename_file = os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav"
|
214 |
-
os.rename(path, rename_file)
|
215 |
-
|
216 |
-
logger.info(f"{translations['dereverb_success']}: {rename_file}")
|
217 |
-
return rename_file
|
218 |
-
|
219 |
-
def separator_main(audio_file=None, model_filename="Kim_Vocal_1.onnx", output_format="wav", output_dir=".", mdx_segment_size=256, mdx_overlap=0.25, mdx_batch_size=1, mdx_hop_length=1024, mdx_enable_denoise=True, sample_rate=44100):
|
220 |
-
try:
|
221 |
-
separator = Separator(logger=logger, log_formatter=file_formatter, log_level=logging.INFO, output_dir=output_dir, output_format=output_format, output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=sample_rate, mdx_params={"hop_length": mdx_hop_length, "segment_size": mdx_segment_size, "overlap": mdx_overlap, "batch_size": mdx_batch_size, "enable_denoise": mdx_enable_denoise})
|
222 |
-
separator.load_model(model_filename=model_filename)
|
223 |
-
return separator.separate(audio_file)
|
224 |
-
except:
|
225 |
-
logger.debug(translations["default_setting"])
|
226 |
-
separator = Separator(logger=logger, log_formatter=file_formatter, log_level=logging.INFO, output_dir=output_dir, output_format=output_format, output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=44100, mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": mdx_enable_denoise})
|
227 |
-
separator.load_model(model_filename=model_filename)
|
228 |
-
return separator.separate(audio_file)
|
229 |
-
|
230 |
-
if __name__ == "__main__": main()
|
|
|
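The `downloader` helper in the deleted create_dataset.py relies on yt-dlp's FFmpeg audio-extraction postprocessor. A minimal sketch of the same download-to-WAV pattern (the function name and output layout are illustrative, and ffmpeg must be available on PATH):

    import os
    import yt_dlp

    def download_wav(url, out_dir, name):
        ydl_opts = {
            "format": "bestaudio/best",
            "outtmpl": os.path.join(out_dir, name),
            # convert the downloaded stream to WAV with ffmpeg
            "postprocessors": [{"key": "FFmpegExtractAudio", "preferredcodec": "wav"}],
            "noplaylist": True,
            "no_warnings": True,
        }
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.extract_info(url)
        return os.path.join(out_dir, name + ".wav")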
main/inference/create_index.py
DELETED
@@ -1,90 +0,0 @@
-import os
-import sys
-import faiss
-import logging
-import argparse
-import logging.handlers
-
-import numpy as np
-
-from multiprocessing import cpu_count
-from sklearn.cluster import MiniBatchKMeans
-
-sys.path.append(os.getcwd())
-
-from main.configs.config import Config
-translations = Config().translations
-
-def parse_arguments():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model_name", type=str, required=True)
-    parser.add_argument("--rvc_version", type=str, default="v2")
-    parser.add_argument("--index_algorithm", type=str, default="Auto")
-
-    return parser.parse_args()
-
-def main():
-    args = parse_arguments()
-    exp_dir = os.path.join("assets", "logs", args.model_name)
-    version, index_algorithm = args.rvc_version, args.index_algorithm
-    logger = logging.getLogger(__name__)
-
-    if logger.hasHandlers(): logger.handlers.clear()
-    else:
-        console_handler = logging.StreamHandler()
-        console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
-        console_handler.setFormatter(console_formatter)
-        console_handler.setLevel(logging.INFO)
-        file_handler = logging.handlers.RotatingFileHandler(os.path.join(exp_dir, "create_index.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
-        file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
-        file_handler.setFormatter(file_formatter)
-        file_handler.setLevel(logging.DEBUG)
-        logger.addHandler(console_handler)
-        logger.addHandler(file_handler)
-        logger.setLevel(logging.DEBUG)
-
-    log_data = {translations['modelname']: args.model_name, translations['model_path']: exp_dir, translations['training_version']: version, translations['index_algorithm_info']: index_algorithm}
-    for key, value in log_data.items():
-        logger.debug(f"{key}: {value}")
-
-    try:
-        npys = []
-        feature_dir = os.path.join(exp_dir, f"{version}_extracted")
-        model_name = os.path.basename(exp_dir)
-
-        for name in sorted(os.listdir(feature_dir)):
-            npys.append(np.load(os.path.join(feature_dir, name)))
-
-        big_npy = np.concatenate(npys, axis=0)
-        big_npy_idx = np.arange(big_npy.shape[0])
-        np.random.shuffle(big_npy_idx)
-        big_npy = big_npy[big_npy_idx]
-
-        if big_npy.shape[0] > 2e5 and (index_algorithm == "Auto" or index_algorithm == "KMeans"): big_npy = (MiniBatchKMeans(n_clusters=10000, verbose=True, batch_size=256 * cpu_count(), compute_labels=False, init="random").fit(big_npy).cluster_centers_)
-        np.save(os.path.join(exp_dir, "total_fea.npy"), big_npy)
-
-        n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39)
-        index_trained = faiss.index_factory(256 if version == "v1" else 768, f"IVF{n_ivf},Flat")
-        index_ivf_trained = faiss.extract_index_ivf(index_trained)
-        index_ivf_trained.nprobe = 1
-        index_trained.train(big_npy)
-        faiss.write_index(index_trained, os.path.join(exp_dir, f"trained_IVF{n_ivf}_Flat_nprobe_{index_ivf_trained.nprobe}_{model_name}_{version}.index"))
-
-        index_added = faiss.index_factory(256 if version == "v1" else 768, f"IVF{n_ivf},Flat")
-        index_ivf_added = faiss.extract_index_ivf(index_added)
-        index_ivf_added.nprobe = 1
-        index_added.train(big_npy)
-        batch_size_add = 8192
-
-        for i in range(0, big_npy.shape[0], batch_size_add):
-            index_added.add(big_npy[i : i + batch_size_add])
-
-        index_filepath_added = os.path.join(exp_dir, f"added_IVF{n_ivf}_Flat_nprobe_{index_ivf_added.nprobe}_{model_name}_{version}.index")
-        faiss.write_index(index_added, index_filepath_added)
-        logger.info(f"{translations['save_index']} '{index_filepath_added}'")
-    except Exception as e:
-        logger.error(f"{translations['create_index_error']}: {e}")
-        import traceback
-        logger.debug(traceback.format_exc())
-
-if __name__ == "__main__": main()
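For reference, the `trained_*.index` and `added_*.index` files written above are ordinary FAISS IVF indexes, and querying one at inference time is a plain train/add/search cycle. A small sketch on random data (the dimension and nlist below are illustrative, not the script's n_ivf heuristic):

    import faiss
    import numpy as np

    d, nlist = 768, 256                               # v2 features are 768-dimensional
    xb = np.random.rand(10000, d).astype(np.float32)  # stand-in for extracted features

    index = faiss.index_factory(d, f"IVF{nlist},Flat")
    index.train(xb)                                   # learn coarse-quantizer centroids
    index.add(xb)                                     # add vectors (batching also works)
    faiss.extract_index_ivf(index).nprobe = 1         # probe a single inverted list
    distances, ids = index.search(xb[:4], 8)          # 8 nearest neighbours per query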
main/inference/extract.py
DELETED
@@ -1,360 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import re
|
3 |
-
import sys
|
4 |
-
import time
|
5 |
-
import tqdm
|
6 |
-
import torch
|
7 |
-
import shutil
|
8 |
-
import logging
|
9 |
-
import argparse
|
10 |
-
import warnings
|
11 |
-
import onnxruntime
|
12 |
-
import logging.handlers
|
13 |
-
|
14 |
-
import numpy as np
|
15 |
-
import soundfile as sf
|
16 |
-
import torch.nn.functional as F
|
17 |
-
|
18 |
-
from random import shuffle
|
19 |
-
from distutils.util import strtobool
|
20 |
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
21 |
-
|
22 |
-
sys.path.append(os.getcwd())
|
23 |
-
|
24 |
-
from main.configs.config import Config
|
25 |
-
from main.library.utils import check_predictors, check_embedders, load_audio, load_embedders_model
|
26 |
-
|
27 |
-
logger = logging.getLogger(__name__)
|
28 |
-
config = Config()
|
29 |
-
translations = config.translations
|
30 |
-
logger.propagate = False
|
31 |
-
|
32 |
-
warnings.filterwarnings("ignore")
|
33 |
-
for l in ["torch", "faiss", "httpx", "fairseq", "httpcore", "faiss.loader", "numba.core", "urllib3", "matplotlib"]:
|
34 |
-
logging.getLogger(l).setLevel(logging.ERROR)
|
35 |
-
|
36 |
-
def parse_arguments():
|
37 |
-
parser = argparse.ArgumentParser()
|
38 |
-
parser.add_argument("--model_name", type=str, required=True)
|
39 |
-
parser.add_argument("--rvc_version", type=str, default="v2")
|
40 |
-
parser.add_argument("--f0_method", type=str, default="rmvpe")
|
41 |
-
parser.add_argument("--pitch_guidance", type=lambda x: bool(strtobool(x)), default=True)
|
42 |
-
parser.add_argument("--hop_length", type=int, default=128)
|
43 |
-
parser.add_argument("--cpu_cores", type=int, default=2)
|
44 |
-
parser.add_argument("--gpu", type=str, default="-")
|
45 |
-
parser.add_argument("--sample_rate", type=int, required=True)
|
46 |
-
parser.add_argument("--embedder_model", type=str, default="contentvec_base")
|
47 |
-
parser.add_argument("--f0_onnx", type=lambda x: bool(strtobool(x)), default=False)
|
48 |
-
parser.add_argument("--embedders_mode", type=str, default="fairseq")
|
49 |
-
|
50 |
-
return parser.parse_args()
|
51 |
-
|
52 |
-
def generate_config(rvc_version, sample_rate, model_path):
|
53 |
-
config_save_path = os.path.join(model_path, "config.json")
|
54 |
-
if not os.path.exists(config_save_path): shutil.copy(os.path.join("main", "configs", rvc_version, f"{sample_rate}.json"), config_save_path)
|
55 |
-
|
56 |
-
def generate_filelist(pitch_guidance, model_path, rvc_version, sample_rate):
|
57 |
-
gt_wavs_dir, feature_dir = os.path.join(model_path, "sliced_audios"), os.path.join(model_path, f"{rvc_version}_extracted")
|
58 |
-
f0_dir, f0nsf_dir = None, None
|
59 |
-
|
60 |
-
if pitch_guidance: f0_dir, f0nsf_dir = os.path.join(model_path, "f0"), os.path.join(model_path, "f0_voiced")
|
61 |
-
|
62 |
-
gt_wavs_files, feature_files = set(name.split(".")[0] for name in os.listdir(gt_wavs_dir)), set(name.split(".")[0] for name in os.listdir(feature_dir))
|
63 |
-
names = gt_wavs_files & feature_files & set(name.split(".")[0] for name in os.listdir(f0_dir)) & set(name.split(".")[0] for name in os.listdir(f0nsf_dir)) if pitch_guidance else gt_wavs_files & feature_files
|
64 |
-
|
65 |
-
options = []
|
66 |
-
mute_base_path = os.path.join("assets", "logs", "mute")
|
67 |
-
|
68 |
-
for name in names:
|
69 |
-
options.append(f"{gt_wavs_dir}/{name}.wav|{feature_dir}/{name}.npy|{f0_dir}/{name}.wav.npy|{f0nsf_dir}/{name}.wav.npy|0" if pitch_guidance else f"{gt_wavs_dir}/{name}.wav|{feature_dir}/{name}.npy|0")
|
70 |
-
|
71 |
-
mute_audio_path, mute_feature_path = os.path.join(mute_base_path, "sliced_audios", f"mute{sample_rate}.wav"), os.path.join(mute_base_path, f"{rvc_version}_extracted", "mute.npy")
|
72 |
-
for _ in range(2):
|
73 |
-
options.append(f"{mute_audio_path}|{mute_feature_path}|{os.path.join(mute_base_path, 'f0', 'mute.wav.npy')}|{os.path.join(mute_base_path, 'f0_voiced', 'mute.wav.npy')}|0" if pitch_guidance else f"{mute_audio_path}|{mute_feature_path}|0")
|
74 |
-
|
75 |
-
shuffle(options)
|
76 |
-
with open(os.path.join(model_path, "filelist.txt"), "w") as f:
|
77 |
-
f.write("\n".join(options))
|
78 |
-
|
79 |
-
def setup_paths(exp_dir, version = None):
|
80 |
-
wav_path = os.path.join(exp_dir, "sliced_audios_16k")
|
81 |
-
|
82 |
-
if version:
|
83 |
-
out_path = os.path.join(exp_dir, f"{version}_extracted")
|
84 |
-
os.makedirs(out_path, exist_ok=True)
|
85 |
-
return wav_path, out_path
|
86 |
-
else:
|
87 |
-
output_root1, output_root2 = os.path.join(exp_dir, "f0"), os.path.join(exp_dir, "f0_voiced")
|
88 |
-
os.makedirs(output_root1, exist_ok=True); os.makedirs(output_root2, exist_ok=True)
|
89 |
-
return wav_path, output_root1, output_root2
|
90 |
-
|
91 |
-
def read_wave(wav_path, normalize = False, is_half = False):
|
92 |
-
wav, sr = sf.read(wav_path, dtype=np.float32)
|
93 |
-
assert sr == 16000, translations["sr_not_16000"]
|
94 |
-
|
95 |
-
feats = torch.from_numpy(wav).float()
|
96 |
-
if feats.dim() == 2: feats = feats.mean(-1)
|
97 |
-
feats = feats.view(1, -1)
|
98 |
-
|
99 |
-
if normalize: feats = F.layer_norm(feats, feats.shape)
|
100 |
-
return feats.half() if is_half else feats.float()
|
101 |
-
|
102 |
-
def get_device(gpu_index):
|
103 |
-
try:
|
104 |
-
index = int(gpu_index)
|
105 |
-
if index < torch.cuda.device_count(): return f"cuda:{index}"
|
106 |
-
else: logger.warning(translations["gpu_not_valid"])
|
107 |
-
except ValueError:
|
108 |
-
logger.warning(translations["gpu_not_valid"])
|
109 |
-
return "cpu"
|
110 |
-
|
111 |
-
def get_providers():
|
112 |
-
ort_providers = onnxruntime.get_available_providers()
|
113 |
-
|
114 |
-
if "CUDAExecutionProvider" in ort_providers: providers = ["CUDAExecutionProvider"]
|
115 |
-
elif "CoreMLExecutionProvider" in ort_providers: providers = ["CoreMLExecutionProvider"]
|
116 |
-
else: providers = ["CPUExecutionProvider"]
|
117 |
-
|
118 |
-
return providers
|
119 |
-
|
120 |
-
class FeatureInput:
|
121 |
-
def __init__(self, sample_rate=16000, hop_size=160, is_half=False, device=config.device):
|
122 |
-
self.fs = sample_rate
|
123 |
-
self.hop = hop_size
|
124 |
-
self.f0_bin = 256
|
125 |
-
self.f0_max = 1100.0
|
126 |
-
self.f0_min = 50.0
|
127 |
-
self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
|
128 |
-
self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)
|
129 |
-
self.device = device
|
130 |
-
self.is_half = is_half
|
131 |
-
|
132 |
-
def compute_f0_hybrid(self, methods_str, np_arr, hop_length, f0_onnx):
|
133 |
-
methods_str = re.search("hybrid\[(.+)\]", methods_str)
|
134 |
-
if methods_str: methods = [method.strip() for method in methods_str.group(1).split("+")]
|
135 |
-
f0_computation_stack, resampled_stack = [], []
|
136 |
-
logger.debug(translations["hybrid_methods"].format(methods=methods))
|
137 |
-
|
138 |
-
for method in methods:
|
139 |
-
f0 = None
|
140 |
-
f0_methods = {"pm": lambda: self.get_pm(np_arr), "dio": lambda: self.get_pyworld(np_arr, "dio"), "mangio-crepe-full": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "full", onnx=f0_onnx), "mangio-crepe-large": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "large", onnx=f0_onnx), "mangio-crepe-medium": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "medium", onnx=f0_onnx), "mangio-crepe-small": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "small", onnx=f0_onnx), "mangio-crepe-tiny": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "tiny", onnx=f0_onnx), "crepe-full": lambda: self.get_crepe(np_arr, "full", onnx=f0_onnx), "crepe-large": lambda: self.get_crepe(np_arr, "large", onnx=f0_onnx), "crepe-medium": lambda: self.get_crepe(np_arr, "medium", onnx=f0_onnx), "crepe-small": lambda: self.get_crepe(np_arr, "small", onnx=f0_onnx), "crepe-tiny": lambda: self.get_crepe(np_arr, "tiny", onnx=f0_onnx), "fcpe": lambda: self.get_fcpe(np_arr, int(hop_length), onnx=f0_onnx), "fcpe-legacy": lambda: self.get_fcpe(np_arr, int(hop_length), legacy=True, onnx=f0_onnx), "rmvpe": lambda: self.get_rmvpe(np_arr, onnx=f0_onnx), "rmvpe-legacy": lambda: self.get_rmvpe(np_arr, legacy=True, onnx=f0_onnx), "harvest": lambda: self.get_pyworld(np_arr, "harvest"), "swipe": lambda: self.get_swipe(np_arr), "yin": lambda: self.get_yin(np_arr, int(hop_length), mode="yin"), "pyin": lambda: self.get_yin(np_arr, int(hop_length), mode="pyin")}
|
141 |
-
f0 = f0_methods.get(method, lambda: ValueError(translations["method_not_valid"]))()
|
142 |
-
f0_computation_stack.append(f0)
|
143 |
-
|
144 |
-
for f0 in f0_computation_stack:
|
145 |
-
resampled_stack.append(np.interp(np.linspace(0, len(f0), (np_arr.size // self.hop)), np.arange(len(f0)), f0))
|
146 |
-
|
147 |
-
return resampled_stack[0] if len(resampled_stack) == 1 else np.nanmedian(np.vstack(resampled_stack), axis=0)
|
148 |
-
|
149 |
-
def compute_f0(self, np_arr, f0_method, hop_length, f0_onnx=False):
|
150 |
-
f0_methods = {"pm": lambda: self.get_pm(np_arr), "dio": lambda: self.get_pyworld(np_arr, "dio"), "mangio-crepe-full": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "full", onnx=f0_onnx), "mangio-crepe-large": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "large", onnx=f0_onnx), "mangio-crepe-medium": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "medium", onnx=f0_onnx), "mangio-crepe-small": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "small", onnx=f0_onnx), "mangio-crepe-tiny": lambda: self.get_mangio_crepe(np_arr, int(hop_length), "tiny", onnx=f0_onnx), "crepe-full": lambda: self.get_crepe(np_arr, "full", onnx=f0_onnx), "crepe-large": lambda: self.get_crepe(np_arr, "large", onnx=f0_onnx), "crepe-medium": lambda: self.get_crepe(np_arr, "medium", onnx=f0_onnx), "crepe-small": lambda: self.get_crepe(np_arr, "small", onnx=f0_onnx), "crepe-tiny": lambda: self.get_crepe(np_arr, "tiny", onnx=f0_onnx), "fcpe": lambda: self.get_fcpe(np_arr, int(hop_length), onnx=f0_onnx), "fcpe-legacy": lambda: self.get_fcpe(np_arr, int(hop_length), legacy=True, onnx=f0_onnx), "rmvpe": lambda: self.get_rmvpe(np_arr, onnx=f0_onnx), "rmvpe-legacy": lambda: self.get_rmvpe(np_arr, legacy=True, onnx=f0_onnx), "harvest": lambda: self.get_pyworld(np_arr, "harvest"), "swipe": lambda: self.get_swipe(np_arr), "yin": lambda: self.get_yin(np_arr, int(hop_length), mode="yin"), "pyin": lambda: self.get_yin(np_arr, int(hop_length), mode="pyin")}
|
151 |
-
return self.compute_f0_hybrid(f0_method, np_arr, int(hop_length), f0_onnx) if "hybrid" in f0_method else f0_methods.get(f0_method, lambda: ValueError(translations["method_not_valid"]))()
|
152 |
-
|
153 |
-
def get_pm(self, x):
|
154 |
-
import parselmouth
|
155 |
-
|
156 |
-
f0 = (parselmouth.Sound(x, self.fs).to_pitch_ac(time_step=(160 / 16000 * 1000) / 1000, voicing_threshold=0.6, pitch_floor=50, pitch_ceiling=1100).selected_array["frequency"])
|
157 |
-
pad_size = ((x.size // self.hop) - len(f0) + 1) // 2
|
158 |
-
|
159 |
-
if pad_size > 0 or (x.size // self.hop) - len(f0) - pad_size > 0: f0 = np.pad(f0, [[pad_size, (x.size // self.hop) - len(f0) - pad_size]], mode="constant")
|
160 |
-
return f0
|
161 |
-
|
162 |
-
def get_mangio_crepe(self, x, hop_length, model="full", onnx=False):
|
163 |
-
from main.library.predictors.CREPE import predict
|
164 |
-
|
165 |
-
audio = torch.from_numpy(x.astype(np.float32)).to(self.device)
|
166 |
-
audio /= torch.quantile(torch.abs(audio), 0.999)
|
167 |
-
audio = audio.unsqueeze(0)
|
168 |
-
source = predict(audio, self.fs, hop_length, self.f0_min, self.f0_max, model=model, batch_size=hop_length * 2, device=self.device, pad=True, providers=get_providers(), onnx=onnx).squeeze(0).cpu().float().numpy()
|
169 |
-
source[source < 0.001] = np.nan
|
170 |
-
|
171 |
-
return np.nan_to_num(np.interp(np.arange(0, len(source) * (x.size // self.hop), len(source)) / (x.size // self.hop), np.arange(0, len(source)), source))
|
172 |
-
|
173 |
-
def get_crepe(self, x, model="full", onnx=False):
|
174 |
-
from main.library.predictors.CREPE import predict, mean, median
|
175 |
-
|
176 |
-
f0, pd = predict(torch.tensor(np.copy(x))[None].float(), self.fs, 160, self.f0_min, self.f0_max, model, batch_size=512, device=self.device, return_periodicity=True, providers=get_providers(), onnx=onnx)
|
177 |
-
f0, pd = mean(f0, 3), median(pd, 3)
|
178 |
-
f0[pd < 0.1] = 0
|
179 |
-
|
180 |
-
return f0[0].cpu().numpy()
|
181 |
-
|
182 |
-
def get_fcpe(self, x, hop_length, legacy=False, onnx=False):
|
183 |
-
from main.library.predictors.FCPE import FCPE
|
184 |
-
|
185 |
-
model_fcpe = FCPE(os.path.join("assets", "models", "predictors", ("fcpe_legacy" if legacy else"fcpe") + (".onnx" if onnx else ".pt")), hop_length=int(hop_length), f0_min=int(self.f0_min), f0_max=int(self.f0_max), dtype=torch.float32, device=self.device, sample_rate=self.fs, threshold=0.03 if legacy else 0.006, providers=get_providers(), onnx=onnx, legacy=legacy)
|
186 |
-
f0 = model_fcpe.compute_f0(x, p_len=(x.size // self.hop))
|
187 |
-
|
188 |
-
del model_fcpe
|
189 |
-
return f0
|
190 |
-
|
191 |
-
def get_rmvpe(self, x, legacy=False, onnx=False):
|
192 |
-
from main.library.predictors.RMVPE import RMVPE
|
193 |
-
|
194 |
-
rmvpe_model = RMVPE(os.path.join("assets", "models", "predictors", "rmvpe" + (".onnx" if onnx else ".pt")), is_half=self.is_half, device=self.device, onnx=onnx, providers=get_providers())
|
195 |
-
f0 = rmvpe_model.infer_from_audio_with_pitch(x, thred=0.03, f0_min=self.f0_min, f0_max=self.f0_max) if legacy else rmvpe_model.infer_from_audio(x, thred=0.03)
|
196 |
-
|
197 |
-
del rmvpe_model
|
198 |
-
return f0
|
199 |
-
|
200 |
-
def get_pyworld(self, x, model="harvest"):
|
201 |
-
from main.library.predictors.WORLD_WRAPPER import PYWORLD
|
202 |
-
|
203 |
-
pw = PYWORLD()
|
204 |
-
x = x.astype(np.double)
|
205 |
-
|
206 |
-
if model == "harvest": f0, t = pw.harvest(x, fs=self.fs, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=1000 * self.hop / self.fs)
|
207 |
-
elif model == "dio": f0, t = pw.dio(x, fs=self.fs, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=1000 * self.hop / self.fs)
|
208 |
-
else: raise ValueError(translations["method_not_valid"])
|
209 |
-
|
210 |
-
return pw.stonemask(x, self.fs, t, f0)
|
211 |
-
|
212 |
-
def get_swipe(self, x):
|
213 |
-
from main.library.predictors.SWIPE import swipe
|
214 |
-
|
215 |
-
f0, _ = swipe(x.astype(np.float32), self.fs, f0_floor=self.f0_min, f0_ceil=self.f0_max, frame_period=1000 * self.hop / self.fs)
|
216 |
-
return f0
|
217 |
-
|
218 |
-
def get_yin(self, x, hop_length, mode="yin"):
|
219 |
-
import librosa
|
220 |
-
|
221 |
-
source = np.array(librosa.yin(x.astype(np.float32), sr=self.fs, fmin=self.f0_min, fmax=self.f0_max, hop_length=hop_length) if mode == "yin" else librosa.pyin(x.astype(np.float32), fmin=self.f0_min, fmax=self.f0_max, sr=self.fs, hop_length=hop_length)[0])
|
222 |
-
source[source < 0.001] = np.nan
|
223 |
-
return np.nan_to_num(np.interp(np.arange(0, len(source) * (x.size // self.hop), len(source)) / (x.size // self.hop), np.arange(0, len(source)), source))
|
224 |
-
|
225 |
-
def coarse_f0(self, f0):
|
226 |
-
return np.rint(np.clip(((1127 * np.log(1 + f0 / 700)) - self.f0_mel_min) * (self.f0_bin - 2) / (self.f0_mel_max - self.f0_mel_min) + 1, 1, self.f0_bin - 1)).astype(int)
|
227 |
-
|
228 |
-
def process_file(self, file_info, f0_method, hop_length, f0_onnx):
|
229 |
-
inp_path, opt_path1, opt_path2, np_arr = file_info
|
230 |
-
if os.path.exists(opt_path1 + ".npy") and os.path.exists(opt_path2 + ".npy"): return
|
231 |
-
|
232 |
-
try:
|
233 |
-
feature_pit = self.compute_f0(np_arr, f0_method, hop_length, f0_onnx)
|
234 |
-
if isinstance(feature_pit, tuple): feature_pit = feature_pit[0]
|
235 |
-
np.save(opt_path2, feature_pit, allow_pickle=False)
|
236 |
-
np.save(opt_path1, self.coarse_f0(feature_pit), allow_pickle=False)
|
237 |
-
except Exception as e:
|
238 |
-
raise RuntimeError(f"{translations['extract_file_error']} {inp_path}: {e}")
|
239 |
-
|
240 |
-
def process_files(self, files, f0_method, hop_length, f0_onnx, pbar):
|
241 |
-
for file_info in files:
|
242 |
-
self.process_file(file_info, f0_method, hop_length, f0_onnx)
|
243 |
-
pbar.update()
|
244 |
-
|
245 |
-
def run_pitch_extraction(exp_dir, f0_method, hop_length, num_processes, gpus, f0_onnx, is_half):
|
246 |
-
input_root, *output_roots = setup_paths(exp_dir)
|
247 |
-
output_root1, output_root2 = output_roots if len(output_roots) == 2 else (output_roots[0], None)
|
248 |
-
|
249 |
-
paths = [(os.path.join(input_root, name), os.path.join(output_root1, name) if output_root1 else None, os.path.join(output_root2, name) if output_root2 else None, load_audio(logger, os.path.join(input_root, name), 16000)) for name in sorted(os.listdir(input_root)) if "spec" not in name]
|
250 |
-
logger.info(translations["extract_f0_method"].format(num_processes=num_processes, f0_method=f0_method))
|
251 |
-
|
252 |
-
start_time = time.time()
|
253 |
-
gpus = gpus.split("-")
|
254 |
-
process_partials = []
|
255 |
-
|
256 |
-
pbar = tqdm.tqdm(total=len(paths), ncols=100, unit="p")
|
257 |
-
for idx, gpu in enumerate(gpus):
|
258 |
-
feature_input = FeatureInput(device=get_device(gpu) if gpu != "" else "cpu", is_half=is_half)
|
259 |
-
process_partials.append((feature_input, paths[idx::len(gpus)]))
|
260 |
-
|
261 |
-
with ThreadPoolExecutor(max_workers=num_processes) as executor:
|
262 |
-
for future in as_completed([executor.submit(FeatureInput.process_files, feature_input, part_paths, f0_method, hop_length, f0_onnx, pbar) for feature_input, part_paths in process_partials]):
|
263 |
-
pbar.update(1)
|
264 |
-
logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))
|
265 |
-
future.result()
|
266 |
-
|
267 |
-
pbar.close()
|
268 |
-
logger.info(translations["extract_f0_success"].format(elapsed_time=f"{(time.time() - start_time):.2f}"))
|
269 |
-
|
270 |
-
def extract_features(model, feats, version):
|
271 |
-
return torch.as_tensor(model.run([model.get_outputs()[0].name, model.get_outputs()[1].name], {"feats": feats.detach().cpu().numpy()})[0 if version == "v1" else 1], dtype=torch.float32, device=feats.device)
|
272 |
-
|
273 |
-
def process_file_embedding(file, wav_path, out_path, model, device, version, saved_cfg, embed_suffix, is_half):
|
274 |
-
out_file_path = os.path.join(out_path, file.replace("wav", "npy"))
|
275 |
-
if os.path.exists(out_file_path): return
|
276 |
-
feats = read_wave(os.path.join(wav_path, file), normalize=saved_cfg.task.normalize if saved_cfg else False, is_half=is_half).to(device)
|
277 |
-
|
278 |
-
with torch.no_grad():
|
279 |
-
if embed_suffix == ".pt":
|
280 |
-
model = model.to(device).to(torch.float16 if is_half else torch.float32).eval()
|
281 |
-
logits = model.extract_features(**{"source": feats, "padding_mask": torch.BoolTensor(feats.shape).fill_(False).to(device), "output_layer": 9 if version == "v1" else 12})
|
282 |
-
feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
|
283 |
-
elif embed_suffix == ".onnx": feats = extract_features(model, feats, version).to(device)
|
284 |
-
elif embed_suffix == ".safetensors":
|
285 |
-
model = model.to(device).to(torch.float16 if is_half else torch.float32).eval()
|
286 |
-
logits = model(feats)["last_hidden_state"]
|
287 |
-
feats = (model.final_proj(logits[0]).unsqueeze(0) if version == "v1" else logits)
|
288 |
-
else: raise ValueError(translations["option_not_valid"])
|
289 |
-
|
290 |
-
feats = feats.squeeze(0).float().cpu().numpy()
|
291 |
-
if not np.isnan(feats).any(): np.save(out_file_path, feats, allow_pickle=False)
|
292 |
-
else: logger.warning(f"{file} {translations['NaN']}")
|
293 |
-
|
294 |
-
def run_embedding_extraction(exp_dir, version, gpus, embedder_model, embedders_mode, is_half):
|
295 |
-
wav_path, out_path = setup_paths(exp_dir, version)
|
296 |
-
logger.info(translations["start_extract_hubert"])
|
297 |
-
start_time = time.time()
|
298 |
-
models, saved_cfg, embed_suffix = load_embedders_model(embedder_model, embedders_mode, providers=get_providers())
|
299 |
-
devices = [get_device(gpu) for gpu in (gpus.split("-") if gpus != "-" else ["cpu"])]
|
300 |
-
paths = sorted([file for file in os.listdir(wav_path) if file.endswith(".wav")])
|
301 |
-
|
302 |
-
if not paths:
|
303 |
-
logger.warning(translations["not_found_audio_file"])
|
304 |
-
sys.exit(1)
|
305 |
-
|
306 |
-
pbar = tqdm.tqdm(total=len(paths) * len(devices), ncols=100, unit="p")
|
307 |
-
for task in [(file, wav_path, out_path, models, device, version, saved_cfg, embed_suffix, is_half) for file in paths for device in devices]:
|
308 |
-
try:
|
309 |
-
process_file_embedding(*task)
|
310 |
-
except Exception as e:
|
311 |
-
raise RuntimeError(f"{translations['process_error']} {task[0]}: {e}")
|
312 |
-
|
313 |
-
pbar.update(1)
|
314 |
-
logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))
|
315 |
-
|
316 |
-
pbar.close()
|
317 |
-
logger.info(translations["extract_hubert_success"].format(elapsed_time=f"{(time.time() - start_time):.2f}"))
|
318 |
-
|
319 |
-
def main():
|
320 |
-
args = parse_arguments()
|
321 |
-
exp_dir = os.path.join("assets", "logs", args.model_name)
|
322 |
-
f0_method, hop_length, num_processes, gpus, version, pitch_guidance, sample_rate, embedder_model, f0_onnx, embedders_mode = args.f0_method, args.hop_length, args.cpu_cores, args.gpu, args.rvc_version, args.pitch_guidance, args.sample_rate, args.embedder_model, args.f0_onnx, args.embedders_mode
|
323 |
-
|
324 |
-
check_predictors(f0_method, f0_onnx); check_embedders(embedder_model, embedders_mode)
|
325 |
-
if logger.hasHandlers(): logger.handlers.clear()
|
326 |
-
else:
|
327 |
-
console_handler = logging.StreamHandler()
|
328 |
-
console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
329 |
-
console_handler.setFormatter(console_formatter)
|
330 |
-
console_handler.setLevel(logging.INFO)
|
331 |
-
file_handler = logging.handlers.RotatingFileHandler(os.path.join(exp_dir, "extract.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
|
332 |
-
file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
333 |
-
file_handler.setFormatter(file_formatter)
|
334 |
-
file_handler.setLevel(logging.DEBUG)
|
335 |
-
logger.addHandler(console_handler)
|
336 |
-
logger.addHandler(file_handler)
|
337 |
-
logger.setLevel(logging.DEBUG)
|
338 |
-
|
339 |
-
log_data = {translations['modelname']: args.model_name, translations['export_process']: exp_dir, translations['f0_method']: f0_method, translations['pretrain_sr']: sample_rate, translations['cpu_core']: num_processes, "Gpu": gpus, "Hop length": hop_length, translations['training_version']: version, translations['extract_f0']: pitch_guidance, translations['hubert_model']: embedder_model, translations["f0_onnx_mode"]: f0_onnx, translations["embed_mode"]: embedders_mode}
|
340 |
-
for key, value in log_data.items():
|
341 |
-
logger.debug(f"{key}: {value}")
|
342 |
-
|
343 |
-
pid_path = os.path.join(exp_dir, "extract_pid.txt")
|
344 |
-
with open(pid_path, "w") as pid_file:
|
345 |
-
pid_file.write(str(os.getpid()))
|
346 |
-
|
347 |
-
try:
|
348 |
-
run_pitch_extraction(exp_dir, f0_method, hop_length, num_processes, gpus, f0_onnx, config.is_half)
|
349 |
-
run_embedding_extraction(exp_dir, version, gpus, embedder_model, embedders_mode, config.is_half)
|
350 |
-
generate_config(version, sample_rate, exp_dir)
|
351 |
-
generate_filelist(pitch_guidance, exp_dir, version, sample_rate)
|
352 |
-
except Exception as e:
|
353 |
-
logger.error(f"{translations['extract_error']}: {e}")
|
354 |
-
import traceback
|
355 |
-
logger.debug(traceback.format_exc())
|
356 |
-
|
357 |
-
if os.path.exists(pid_path): os.remove(pid_path)
|
358 |
-
logger.info(f"{translations['extract_success']} {args.model_name}.")
|
359 |
-
|
360 |
-
if __name__ == "__main__": main()
|
|
|
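The `coarse_f0` method in the deleted extract.py maps continuous F0 values onto 255 mel-spaced bins before they are saved for training. The same mapping written out step by step, with the constants copied from the `FeatureInput` defaults:

    import numpy as np

    F0_MIN, F0_MAX, F0_BIN = 50.0, 1100.0, 256
    F0_MEL_MIN = 1127 * np.log(1 + F0_MIN / 700)
    F0_MEL_MAX = 1127 * np.log(1 + F0_MAX / 700)

    def coarse_f0(f0):
        # Hz -> mel, rescale into bins 1..F0_BIN-1, then round to integer labels
        f0_mel = 1127 * np.log(1 + f0 / 700)
        scaled = (f0_mel - F0_MEL_MIN) * (F0_BIN - 2) / (F0_MEL_MAX - F0_MEL_MIN) + 1
        return np.rint(np.clip(scaled, 1, F0_BIN - 1)).astype(int)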
main/inference/preprocess.py
DELETED
@@ -1,270 +0,0 @@
-import os
-import sys
-import time
-import logging
-import librosa
-import argparse
-import logging.handlers
-
-import numpy as np
-
-from tqdm import tqdm
-from scipy import signal
-from scipy.io import wavfile
-from distutils.util import strtobool
-from concurrent.futures import ProcessPoolExecutor, as_completed
-
-sys.path.append(os.getcwd())
-
-from main.library.utils import load_audio
-from main.configs.config import Config
-
-logger = logging.getLogger(__name__)
-for l in ["numba.core.byteflow", "numba.core.ssa", "numba.core.interpreter"]:
-    logging.getLogger(l).setLevel(logging.ERROR)
-
-OVERLAP, MAX_AMPLITUDE, ALPHA, HIGH_PASS_CUTOFF, SAMPLE_RATE_16K = 0.3, 0.9, 0.75, 48, 16000
-
-config = Config()
-translations = config.translations
-
-def parse_arguments():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model_name", type=str, required=True)
-    parser.add_argument("--dataset_path", type=str, default="./dataset")
-    parser.add_argument("--sample_rate", type=int, required=True)
-    parser.add_argument("--cpu_cores", type=int, default=2)
-    parser.add_argument("--cut_preprocess", type=lambda x: bool(strtobool(x)), default=True)
-    parser.add_argument("--process_effects", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--clean_dataset", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--clean_strength", type=float, default=0.7)
-
-    return parser.parse_args()
-
-class Slicer:
-    def __init__(self, sr, threshold = -40.0, min_length = 5000, min_interval = 300, hop_size = 20, max_sil_kept = 5000):
-        if not min_length >= min_interval >= hop_size: raise ValueError(translations["min_length>=min_interval>=hop_size"])
-        if not max_sil_kept >= hop_size: raise ValueError(translations["max_sil_kept>=hop_size"])
-
-        min_interval = sr * min_interval / 1000
-        self.threshold = 10 ** (threshold / 20.0)
-        self.hop_size = round(sr * hop_size / 1000)
-        self.win_size = min(round(min_interval), 4 * self.hop_size)
-        self.min_length = round(sr * min_length / 1000 / self.hop_size)
-        self.min_interval = round(min_interval / self.hop_size)
-        self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
-
-    def _apply_slice(self, waveform, begin, end):
-        start_idx = begin * self.hop_size
-
-        if len(waveform.shape) > 1: return waveform[:, start_idx:min(waveform.shape[1], end * self.hop_size)]
-        else: return waveform[start_idx:min(waveform.shape[0], end * self.hop_size)]
-
-    def slice(self, waveform):
-        samples = waveform.mean(axis=0) if len(waveform.shape) > 1 else waveform
-        if samples.shape[0] <= self.min_length: return [waveform]
-        rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
-        sil_tags = []
-        silence_start, clip_start = None, 0
-
-        for i, rms in enumerate(rms_list):
-            if rms < self.threshold:
-                if silence_start is None: silence_start = i
-                continue
-
-            if silence_start is None: continue
-
-            is_leading_silence = silence_start == 0 and i > self.max_sil_kept
-            need_slice_middle = (i - silence_start >= self.min_interval and i - clip_start >= self.min_length)
-
-            if not is_leading_silence and not need_slice_middle:
-                silence_start = None
-                continue
-
-            if i - silence_start <= self.max_sil_kept:
-                pos = rms_list[silence_start : i + 1].argmin() + silence_start
-                sil_tags.append((0, pos) if silence_start == 0 else (pos, pos))
-                clip_start = pos
-            elif i - silence_start <= self.max_sil_kept * 2:
-                pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
-                pos += i - self.max_sil_kept
-                pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)
-
-                if silence_start == 0:
-                    sil_tags.append((0, pos_r))
-                    clip_start = pos_r
-                else:
-                    sil_tags.append((min((rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start), pos), max(pos_r, pos)))
-                    clip_start = max(pos_r, pos)
-            else:
-                pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)
-                sil_tags.append((0, pos_r) if silence_start == 0 else ((rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start), pos_r))
-                clip_start = pos_r
-
-            silence_start = None
-        total_frames = rms_list.shape[0]
-        if (silence_start is not None and total_frames - silence_start >= self.min_interval): sil_tags.append((rms_list[silence_start : min(total_frames, silence_start + self.max_sil_kept) + 1].argmin() + silence_start, total_frames + 1))
-
-        if not sil_tags: return [waveform]
-        else:
-            chunks = []
-            if sil_tags[0][0] > 0: chunks.append(self._apply_slice(waveform, 0, sil_tags[0][0]))
-
-            for i in range(len(sil_tags) - 1):
-                chunks.append(self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]))
-
-            if sil_tags[-1][1] < total_frames: chunks.append(self._apply_slice(waveform, sil_tags[-1][1], total_frames))
-            return chunks
-
-def get_rms(y, frame_length=2048, hop_length=512, pad_mode="constant"):
-    y = np.pad(y, (int(frame_length // 2), int(frame_length // 2)), mode=pad_mode)
-    axis = -1
-    x_shape_trimmed = list(y.shape)
-    x_shape_trimmed[axis] -= frame_length - 1
-    xw = np.moveaxis(np.lib.stride_tricks.as_strided(y, shape=tuple(x_shape_trimmed) + tuple([frame_length]), strides=y.strides + tuple([y.strides[axis]])), -1, axis - 1 if axis < 0 else axis + 1)
-    slices = [slice(None)] * xw.ndim
-    slices[axis] = slice(0, None, hop_length)
-    return np.sqrt(np.mean(np.abs(xw[tuple(slices)]) ** 2, axis=-2, keepdims=True))
-
-class PreProcess:
-    def __init__(self, sr, exp_dir, per):
-        self.slicer = Slicer(sr=sr, threshold=-42, min_length=1500, min_interval=400, hop_size=15, max_sil_kept=500)
-        self.sr = sr
-        self.b_high, self.a_high = signal.butter(N=5, Wn=HIGH_PASS_CUTOFF, btype="high", fs=self.sr)
-        self.per = per
-        self.exp_dir = exp_dir
-        self.device = "cpu"
-        self.gt_wavs_dir = os.path.join(exp_dir, "sliced_audios")
-        self.wavs16k_dir = os.path.join(exp_dir, "sliced_audios_16k")
-        os.makedirs(self.gt_wavs_dir, exist_ok=True)
-        os.makedirs(self.wavs16k_dir, exist_ok=True)
-
-    def _normalize_audio(self, audio):
-        tmp_max = np.abs(audio).max()
-        if tmp_max > 2.5: return None
-        return (audio / tmp_max * (MAX_AMPLITUDE * ALPHA)) + (1 - ALPHA) * audio
-
-    def process_audio_segment(self, normalized_audio, sid, idx0, idx1):
-        if normalized_audio is None:
-            logger.debug(f"{sid}-{idx0}-{idx1}-filtered")
-            return
-
-        wavfile.write(os.path.join(self.gt_wavs_dir, f"{sid}_{idx0}_{idx1}.wav"), self.sr, normalized_audio.astype(np.float32))
-        wavfile.write(os.path.join(self.wavs16k_dir, f"{sid}_{idx0}_{idx1}.wav"), SAMPLE_RATE_16K, librosa.resample(normalized_audio, orig_sr=self.sr, target_sr=SAMPLE_RATE_16K, res_type="soxr_vhq").astype(np.float32))
-
-    def process_audio(self, path, idx0, sid, cut_preprocess, process_effects, clean_dataset, clean_strength):
-        try:
-            audio = load_audio(logger, path, self.sr)
-
-            if process_effects:
-                audio = signal.lfilter(self.b_high, self.a_high, audio)
-                audio = self._normalize_audio(audio)
-
-            if clean_dataset:
-                from main.tools.noisereduce import reduce_noise
-                audio = reduce_noise(y=audio, sr=self.sr, prop_decrease=clean_strength, device=config.device)
-
-            idx1 = 0
-            if cut_preprocess:
-                for audio_segment in self.slicer.slice(audio):
-                    i = 0
-
-                    while 1:
-                        start = int(self.sr * (self.per - OVERLAP) * i)
-                        i += 1
-
-                        if len(audio_segment[start:]) > (self.per + OVERLAP) * self.sr:
-                            self.process_audio_segment(audio_segment[start : start + int(self.per * self.sr)], sid, idx0, idx1)
-                            idx1 += 1
-                        else:
-                            self.process_audio_segment(audio_segment[start:], sid, idx0, idx1)
-                            idx1 += 1
-                            break
-            else: self.process_audio_segment(audio, sid, idx0, idx1)
-        except Exception as e:
-            raise RuntimeError(f"{translations['process_audio_error']}: {e}")
-
-def process_file(args):
-    pp, file, cut_preprocess, process_effects, clean_dataset, clean_strength = (args)
-    file_path, idx0, sid = file
-    pp.process_audio(file_path, idx0, sid, cut_preprocess, process_effects, clean_dataset, clean_strength)
-
-def preprocess_training_set(input_root, sr, num_processes, exp_dir, per, cut_preprocess, process_effects, clean_dataset, clean_strength):
-    start_time = time.time()
-
-    pp = PreProcess(sr, exp_dir, per)
-    logger.info(translations["start_preprocess"].format(num_processes=num_processes))
-    files = []
-    idx = 0
-
-    for root, _, filenames in os.walk(input_root):
-        try:
-            sid = 0 if root == input_root else int(os.path.basename(root))
-
-            for f in filenames:
-                if f.lower().endswith(("wav", "mp3", "flac", "ogg", "opus", "m4a", "mp4", "aac", "alac", "wma", "aiff", "webm", "ac3")):
-                    files.append((os.path.join(root, f), idx, sid))
-                    idx += 1
-        except ValueError:
-            raise ValueError(f"{translations['not_integer']} '{os.path.basename(root)}'.")
-
-    with tqdm(total=len(files), ncols=100, unit="f") as pbar:
-        with ProcessPoolExecutor(max_workers=num_processes) as executor:
-            futures = [executor.submit(process_file, (pp, file, cut_preprocess, process_effects, clean_dataset, clean_strength)) for file in files]
-            for future in as_completed(futures):
-                try:
-                    future.result()
-                except Exception as e:
-                    raise RuntimeError(f"{translations['process_error']}: {e}")
-                pbar.update(1)
-                logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))
-
-    elapsed_time = time.time() - start_time
-    logger.info(translations["preprocess_success"].format(elapsed_time=f"{elapsed_time:.2f}"))
-
-def main():
-    args = parse_arguments()
-    experiment_directory = os.path.join("assets", "logs", args.model_name)
-
-    num_processes = args.cpu_cores
-    num_processes = 2 if num_processes is None else int(num_processes)
-
-    dataset, sample_rate, cut_preprocess, preprocess_effects, clean_dataset, clean_strength = args.dataset_path, args.sample_rate, args.cut_preprocess, args.process_effects, args.clean_dataset, args.clean_strength
-
-    os.makedirs(experiment_directory, exist_ok=True)
-
-    if logger.hasHandlers(): logger.handlers.clear()
-    else:
-        console_handler = logging.StreamHandler()
-        console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
-        console_handler.setFormatter(console_formatter)
-        console_handler.setLevel(logging.INFO)
-        file_handler = logging.handlers.RotatingFileHandler(os.path.join(experiment_directory, "preprocess.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
-        file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
-        file_handler.setFormatter(file_formatter)
-        file_handler.setLevel(logging.DEBUG)
-        logger.addHandler(console_handler)
-        logger.addHandler(file_handler)
-        logger.setLevel(logging.DEBUG)
-
-    log_data = {translations['modelname']: args.model_name, translations['export_process']: experiment_directory, translations['dataset_folder']: dataset, translations['pretrain_sr']: sample_rate, translations['cpu_core']: num_processes, translations['split_audio']: cut_preprocess, translations['preprocess_effect']: preprocess_effects, translations['clear_audio']: clean_dataset}
-    if clean_dataset: log_data[translations['clean_strength']] = clean_strength
-
-    for key, value in log_data.items():
-        logger.debug(f"{key}: {value}")
-
-    pid_path = os.path.join(experiment_directory, "preprocess_pid.txt")
-    with open(pid_path, "w") as pid_file:
-        pid_file.write(str(os.getpid()))
-
-    try:
-        preprocess_training_set(dataset, sample_rate, num_processes, experiment_directory, config.per_preprocess, cut_preprocess, preprocess_effects, clean_dataset, clean_strength)
-    except Exception as e:
-        logger.error(f"{translations['process_audio_error']} {e}")
-        import traceback
-        logger.debug(traceback.format_exc())
-
-    if os.path.exists(pid_path): os.remove(pid_path)
-    logger.info(f"{translations['preprocess_model_success']} {args.model_name}")
-
-if __name__ == "__main__": main()
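Judging from the argparse block in the deleted preprocess.py above, the script was driven entirely by command-line flags. A hypothetical invocation is sketched below; the model name, dataset path, and flag values are placeholders, and calling the module directly with the interpreter is an assumption rather than something documented in this repository.

import subprocess
import sys

# Hypothetical call to the removed preprocessing script; names and values are placeholders.
subprocess.run([
    sys.executable, "main/inference/preprocess.py",
    "--model_name", "my_model",
    "--dataset_path", "./dataset",
    "--sample_rate", "48000",
    "--cpu_cores", "2",
    "--cut_preprocess", "True",
    "--process_effects", "False",
    "--clean_dataset", "False",
    "--clean_strength", "0.7",
], check=True)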
main/inference/separator_music.py
DELETED
@@ -1,310 +0,0 @@
-import os
-import sys
-import time
-import logging
-import argparse
-import logging.handlers
-
-import numpy as np
-
-from distutils.util import strtobool
-
-sys.path.append(os.getcwd())
-
-from main.configs.config import Config
-from main.library.algorithm.separator import Separator
-from main.library.utils import pydub_convert, pydub_load
-
-config = Config()
-translations = config.translations
-logger = logging.getLogger(__name__)
-
-if logger.hasHandlers(): logger.handlers.clear()
-else:
-    console_handler = logging.StreamHandler()
-    console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
-    console_handler.setFormatter(console_formatter)
-    console_handler.setLevel(logging.INFO)
-    file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "separator.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
-    file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
-    file_handler.setFormatter(file_formatter)
-    file_handler.setLevel(logging.DEBUG)
-    logger.addHandler(console_handler)
-    logger.addHandler(file_handler)
-    logger.setLevel(logging.DEBUG)
-
-demucs_models = {"HT-Tuned": "htdemucs_ft.yaml", "HT-Normal": "htdemucs.yaml", "HD_MMI": "hdemucs_mmi.yaml", "HT_6S": "htdemucs_6s.yaml"}
-mdx_models = {"Main_340": "UVR-MDX-NET_Main_340.onnx", "Main_390": "UVR-MDX-NET_Main_390.onnx", "Main_406": "UVR-MDX-NET_Main_406.onnx", "Main_427": "UVR-MDX-NET_Main_427.onnx","Main_438": "UVR-MDX-NET_Main_438.onnx", "Inst_full_292": "UVR-MDX-NET-Inst_full_292.onnx", "Inst_HQ_1": "UVR-MDX-NET-Inst_HQ_1.onnx", "Inst_HQ_2": "UVR-MDX-NET-Inst_HQ_2.onnx", "Inst_HQ_3": "UVR-MDX-NET-Inst_HQ_3.onnx", "Inst_HQ_4": "UVR-MDX-NET-Inst_HQ_4.onnx", "Inst_HQ_5": "UVR-MDX-NET-Inst_HQ_5.onnx", "Kim_Vocal_1": "Kim_Vocal_1.onnx", "Kim_Vocal_2": "Kim_Vocal_2.onnx", "Kim_Inst": "Kim_Inst.onnx", "Inst_187_beta": "UVR-MDX-NET_Inst_187_beta.onnx", "Inst_82_beta": "UVR-MDX-NET_Inst_82_beta.onnx", "Inst_90_beta": "UVR-MDX-NET_Inst_90_beta.onnx", "Voc_FT": "UVR-MDX-NET-Voc_FT.onnx", "Crowd_HQ": "UVR-MDX-NET_Crowd_HQ_1.onnx", "MDXNET_9482": "UVR_MDXNET_9482.onnx", "Inst_1": "UVR-MDX-NET-Inst_1.onnx", "Inst_2": "UVR-MDX-NET-Inst_2.onnx", "Inst_3": "UVR-MDX-NET-Inst_3.onnx", "MDXNET_1_9703": "UVR_MDXNET_1_9703.onnx", "MDXNET_2_9682": "UVR_MDXNET_2_9682.onnx", "MDXNET_3_9662": "UVR_MDXNET_3_9662.onnx", "Inst_Main": "UVR-MDX-NET-Inst_Main.onnx", "MDXNET_Main": "UVR_MDXNET_Main.onnx"}
-kara_models = {"Version-1": "UVR_MDXNET_KARA.onnx", "Version-2": "UVR_MDXNET_KARA_2.onnx"}
-
-def parse_arguments():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--input_path", type=str, required=True)
-    parser.add_argument("--output_path", type=str, default="./audios")
-    parser.add_argument("--format", type=str, default="wav")
-    parser.add_argument("--shifts", type=int, default=2)
-    parser.add_argument("--segments_size", type=int, default=256)
-    parser.add_argument("--overlap", type=float, default=0.25)
-    parser.add_argument("--mdx_hop_length", type=int, default=1024)
-    parser.add_argument("--mdx_batch_size", type=int, default=1)
-    parser.add_argument("--clean_audio", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--clean_strength", type=float, default=0.7)
-    parser.add_argument("--model_name", type=str, default="HT-Normal")
-    parser.add_argument("--kara_model", type=str, default="Version-1")
-    parser.add_argument("--backing", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--mdx_denoise", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--reverb", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--backing_reverb", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--sample_rate", type=int, default=44100)
-
-    return parser.parse_args()
-
-def main():
-    start_time = time.time()
-    pid_path = os.path.join("assets", "separate_pid.txt")
-
-    with open(pid_path, "w") as pid_file:
-        pid_file.write(str(os.getpid()))
-
-    try:
-        args = parse_arguments()
-        input_path, output_path, export_format, shifts, segments_size, overlap, hop_length, batch_size, clean_audio, clean_strength, model_name, kara_model, backing, mdx_denoise, reverb, backing_reverb, sample_rate = args.input_path, args.output_path, args.format, args.shifts, args.segments_size, args.overlap, args.mdx_hop_length, args.mdx_batch_size, args.clean_audio, args.clean_strength, args.model_name, args.kara_model, args.backing, args.mdx_denoise, args.reverb, args.backing_reverb, args.sample_rate
-
-        if backing_reverb and not reverb:
-            logger.warning(translations["turn_on_dereverb"])
-            sys.exit(1)
-
-        if backing_reverb and not backing:
-            logger.warning(translations["turn_on_separator_backing"])
-            sys.exit(1)
-
-        input_path = input_path.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
-        output_path = os.path.dirname(output_path) or output_path
-
-        log_data = {translations['audio_path']: input_path, translations['output_path']: output_path, translations['export_format']: export_format, translations['shift']: shifts, translations['segments_size']: segments_size, translations['overlap']: overlap, translations['modelname']: model_name, translations['denoise_mdx']: mdx_denoise, "Hop length": hop_length, translations['batch_size']: batch_size, translations['sr']: sample_rate}
-
-        if clean_audio:
-            log_data[translations['clear_audio']] = clean_audio
-            log_data[translations['clean_strength']] = clean_strength
-
-        if backing:
-            log_data[translations['backing_model_ver']] = kara_model
-            log_data[translations['separator_backing']] = backing
-
-        if reverb:
-            log_data[translations['dereveb_audio']] = reverb
-            log_data[translations['dereveb_backing']] = backing_reverb
-
-        for key, value in log_data.items():
-            logger.debug(f"{key}: {value}")
-
-        if os.path.isdir(input_path):
-            for f in input_path:
-                separation(f, output_path, export_format, shifts, overlap, segments_size, model_name, sample_rate, mdx_denoise, hop_length, batch_size, backing, reverb, kara_model, backing_reverb, clean_audio, clean_strength)
-        else: separation(input_path, output_path, export_format, shifts, overlap, segments_size, model_name, sample_rate, mdx_denoise, hop_length, batch_size, backing, reverb, kara_model, backing_reverb, clean_audio, clean_strength)
-
-    except Exception as e:
-        logger.error(f"{translations['separator_error']}: {e}")
-        import traceback
-        logger.debug(traceback.format_exc())
-
-    if os.path.exists(pid_path): os.remove(pid_path)
-    elapsed_time = time.time() - start_time
-    logger.info(translations["separator_success"].format(elapsed_time=f"{elapsed_time:.2f}"))
-
-def separation(input_path, output_path, export_format, shifts, overlap, segments_size, model_name, sample_rate, mdx_denoise, hop_length, batch_size, backing, reverb, kara_model, backing_reverb, clean_audio, clean_strength):
-    filename, _ = os.path.splitext(os.path.basename(input_path))
-    output_path = os.path.join(output_path, filename)
-    os.makedirs(output_path, exist_ok=True)
-
-    if model_name in ["HT-Tuned", "HT-Normal", "HD_MMI", "HT_6S"]: vocals, _ = separator_music_demucs(input_path, output_path, export_format, shifts, overlap, segments_size, model_name, sample_rate)
-    else: vocals, _ = separator_music_mdx(input_path, output_path, export_format, segments_size, overlap, mdx_denoise, model_name, hop_length, batch_size, sample_rate)
-
-    if backing: main_vocals, backing_vocals = separator_backing(vocals, output_path, export_format, segments_size, overlap, mdx_denoise, kara_model, hop_length, batch_size, sample_rate)
-    if reverb: vocals_no_reverb, main_vocals_no_reverb, backing_vocals_no_reverb = separator_reverb(output_path, export_format, segments_size, overlap, mdx_denoise, reverb, backing_reverb, hop_length, batch_size, sample_rate)
-
-    original_output = os.path.join(output_path, f"Original_Vocals_No_Reverb.{export_format}") if reverb else os.path.join(output_path, f"Original_Vocals.{export_format}")
-    main_output = os.path.join(output_path, f"Main_Vocals_No_Reverb.{export_format}") if reverb and backing_reverb else os.path.join(output_path, f"Main_Vocals.{export_format}")
-    backing_output = os.path.join(output_path, f"Backing_Vocals_No_Reverb.{export_format}") if reverb and backing_reverb else os.path.join(output_path, f"Backing_Vocals.{export_format}")
-
-    if clean_audio:
-        import soundfile as sf
-
-        logger.info(f"{translations['clear_audio']}...")
-        vocal_data, vocal_sr = sf.read(vocals_no_reverb if reverb else vocals, dtype=np.float32)
-
-        from main.tools.noisereduce import reduce_noise
-        sf.write(original_output, reduce_noise(y=vocal_data, sr=vocal_sr, prop_decrease=clean_strength), vocal_sr, format=export_format, device=config.device)
-
-        if backing:
-            main_data, main_sr = sf.read(main_vocals_no_reverb if reverb and backing else main_vocals, dtype=np.float32)
-            backing_data, backing_sr = sf.read(backing_vocals_no_reverb if reverb and backing_reverb else backing_vocals, dtype=np.float32)
-
-            sf.write(main_output, reduce_noise(y=main_data, sr=main_sr, prop_decrease=clean_strength), main_sr, format=export_format, device=config.device)
-            sf.write(backing_output, reduce_noise(y=backing_data, sr=backing_sr, prop_decrease=clean_strength), backing_sr, format=export_format, device=config.device)
-
-        logger.info(translations["clean_audio_success"])
-
-def separator_music_demucs(input, output, format, shifts, overlap, segments_size, demucs_model, sample_rate):
-    if not os.path.exists(input):
-        logger.warning(translations["input_not_valid"])
-        sys.exit(1)
-
-    if not os.path.exists(output):
-        logger.warning(translations["output_not_valid"])
-        sys.exit(1)
-
-    for i in [f"Original_Vocals.{format}", f"Instruments.{format}"]:
-        if os.path.exists(os.path.join(output, i)): os.remove(os.path.join(output, i))
-
-    logger.info(f"{translations['separator_process_2']}...")
-    demucs_output = separator_main(audio_file=input, model_filename=demucs_models.get(demucs_model), output_format=format, output_dir=output, demucs_segment_size=(segments_size / 2), demucs_shifts=shifts, demucs_overlap=overlap, sample_rate=sample_rate)
-
-    for f in demucs_output:
-        path = os.path.join(output, f)
-        if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
-
-        if '_(Drums)_' in f: drums = path
-        elif '_(Bass)_' in f: bass = path
-        elif '_(Other)_' in f: other = path
-        elif '_(Vocals)_' in f: os.rename(path, os.path.join(output, f"Original_Vocals.{format}"))
-
-    pydub_convert(pydub_load(drums)).overlay(pydub_convert(pydub_load(bass))).overlay(pydub_convert(pydub_load(other))).export(os.path.join(output, f"Instruments.{format}"), format=format)
-
-    for f in [drums, bass, other]:
-        if os.path.exists(f): os.remove(f)
-
-    logger.info(translations["separator_success_2"])
-    return os.path.join(output, f"Original_Vocals.{format}"), os.path.join(output, f"Instruments.{format}")
-
-def separator_backing(input, output, format, segments_size, overlap, denoise, kara_model, hop_length, batch_size, sample_rate):
-    if not os.path.exists(input):
-        logger.warning(translations["input_not_valid"])
-        sys.exit(1)
-
-    if not os.path.exists(output):
-        logger.warning(translations["output_not_valid"])
-        sys.exit(1)
-
-    for f in [f"Main_Vocals.{format}", f"Backing_Vocals.{format}"]:
-        if os.path.exists(os.path.join(output, f)): os.remove(os.path.join(output, f))
-
-    model_2 = kara_models.get(kara_model)
-    logger.info(f"{translations['separator_process_backing']}...")
-
-    backing_outputs = separator_main(audio_file=input, model_filename=model_2, output_format=format, output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise, sample_rate=sample_rate)
-    main_output = os.path.join(output, f"Main_Vocals.{format}")
-    backing_output = os.path.join(output, f"Backing_Vocals.{format}")
-
-    for f in backing_outputs:
-        path = os.path.join(output, f)
-        if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
-
-        if '_(Instrumental)_' in f: os.rename(path, backing_output)
-        elif '_(Vocals)_' in f: os.rename(path, main_output)
-
-    logger.info(translations["separator_process_backing_success"])
-    return main_output, backing_output
-
-def separator_music_mdx(input, output, format, segments_size, overlap, denoise, mdx_model, hop_length, batch_size, sample_rate):
-    if not os.path.exists(input):
-        logger.warning(translations["input_not_valid"])
-        sys.exit(1)
-
-    if not os.path.exists(output):
-        logger.warning(translations["output_not_valid"])
-        sys.exit(1)
-
-    for i in [f"Original_Vocals.{format}", f"Instruments.{format}"]:
-        if os.path.exists(os.path.join(output, i)): os.remove(os.path.join(output, i))
-
-    model_3 = mdx_models.get(mdx_model)
-    logger.info(f"{translations['separator_process_2']}...")
-
-    output_music = separator_main(audio_file=input, model_filename=model_3, output_format=format, output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise, sample_rate=sample_rate)
-    original_output, instruments_output = os.path.join(output, f"Original_Vocals.{format}"), os.path.join(output, f"Instruments.{format}")
-
-    for f in output_music:
-        path = os.path.join(output, f)
-        if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
-
-        if '_(Instrumental)_' in f: os.rename(path, instruments_output)
-        elif '_(Vocals)_' in f: os.rename(path, original_output)
-
-    logger.info(translations["separator_process_backing_success"])
-    return original_output, instruments_output
-
-def separator_reverb(output, format, segments_size, overlap, denoise, original, backing_reverb, hop_length, batch_size, sample_rate):
-    if not os.path.exists(output):
-        logger.warning(translations["output_not_valid"])
-        sys.exit(1)
-
-    for i in [f"Original_Vocals_Reverb.{format}", f"Main_Vocals_Reverb.{format}", f"Original_Vocals_No_Reverb.{format}", f"Main_Vocals_No_Reverb.{format}"]:
-        if os.path.exists(os.path.join(output, i)): os.remove(os.path.join(output, i))
-
-    dereveb_path = []
-
-    if original:
-        try:
-            dereveb_path.append(os.path.join(output, [f for f in os.listdir(output) if 'Original_Vocals' in f][0]))
-        except IndexError:
-            logger.warning(translations["not_found_original_vocal"])
-            sys.exit(1)
-
-    if backing_reverb:
-        try:
-            dereveb_path.append(os.path.join(output, [f for f in os.listdir(output) if 'Main_Vocals' in f][0]))
-        except IndexError:
-            logger.warning(translations["not_found_main_vocal"])
-            sys.exit(1)
-
-    if backing_reverb:
-        try:
-            dereveb_path.append(os.path.join(output, [f for f in os.listdir(output) if 'Backing_Vocals' in f][0]))
-        except IndexError:
-            logger.warning(translations["not_found_backing_vocal"])
-            sys.exit(1)
-
-    for path in dereveb_path:
-        if not os.path.exists(path):
-            logger.warning(translations["not_found"].format(name=path))
-            sys.exit(1)
-
-        if "Original_Vocals" in path:
-            reverb_path, no_reverb_path = os.path.join(output, f"Original_Vocals_Reverb.{format}"), os.path.join(output, f"Original_Vocals_No_Reverb.{format}")
-            start_title, end_title = translations["process_original"], translations["process_original_success"]
-        elif "Main_Vocals" in path:
-            reverb_path, no_reverb_path = os.path.join(output, f"Main_Vocals_Reverb.{format}"), os.path.join(output, f"Main_Vocals_No_Reverb.{format}")
-            start_title, end_title = translations["process_main"], translations["process_main_success"]
-        elif "Backing_Vocals" in path:
-            reverb_path, no_reverb_path = os.path.join(output, f"Backing_Vocals_Reverb.{format}"), os.path.join(output, f"Backing_Vocals_No_Reverb.{format}")
-            start_title, end_title = translations["process_backing"], translations["process_backing_success"]
-
-        logger.info(start_title)
-        output_dereveb = separator_main(audio_file=path, model_filename="Reverb_HQ_By_FoxJoy.onnx", output_format=format, output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise, sample_rate=sample_rate)
-
-        for f in output_dereveb:
-            path = os.path.join(output, f)
-            if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
-
-            if '_(Reverb)_' in f: os.rename(path, reverb_path)
-            elif '_(No Reverb)_' in f: os.rename(path, no_reverb_path)
-
-        logger.info(end_title)
-
-    return (os.path.join(output, f"Original_Vocals_No_Reverb.{format}") if original else None), (os.path.join(output, f"Main_Vocals_No_Reverb.{format}") if backing_reverb else None), (os.path.join(output, f"Backing_Vocals_No_Reverb.{format}") if backing_reverb else None)
-
-def separator_main(audio_file=None, model_filename="UVR-MDX-NET_Main_340.onnx", output_format="wav", output_dir=".", mdx_segment_size=256, mdx_overlap=0.25, mdx_batch_size=1, mdx_hop_length=1024, mdx_enable_denoise=True, demucs_segment_size=256, demucs_shifts=2, demucs_overlap=0.25, sample_rate=44100):
-    try:
-        separator = Separator(logger=logger, log_formatter=file_formatter, log_level=logging.INFO, output_dir=output_dir, output_format=output_format, output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=sample_rate, mdx_params={"hop_length": mdx_hop_length, "segment_size": mdx_segment_size, "overlap": mdx_overlap, "batch_size": mdx_batch_size, "enable_denoise": mdx_enable_denoise}, demucs_params={"segment_size": demucs_segment_size, "shifts": demucs_shifts, "overlap": demucs_overlap, "segments_enabled": True})
-        separator.load_model(model_filename=model_filename)
-
-        return separator.separate(audio_file)
-    except:
-        logger.debug(translations["default_setting"])
-        separator = Separator(logger=logger, log_formatter=file_formatter, log_level=logging.INFO, output_dir=output_dir, output_format=output_format, output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=44100, mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": mdx_enable_denoise}, demucs_params={"segment_size": 128, "shifts": 2, "overlap": 0.25, "segments_enabled": True})
-        separator.load_model(model_filename=model_filename)
-
-        return separator.separate(audio_file)
-
-if __name__ == "__main__": main()
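The deleted separator_music.py exposes a similar CLI. Two behaviours worth noting from the code itself: it exits early when --backing_reverb is enabled without --reverb or without --backing, and the directory branch of main() iterates over the characters of the input path string rather than over the files inside it. A hypothetical invocation is sketched below; the audio path and flag values are placeholders, and running the module directly is an assumption.

import subprocess
import sys

# Hypothetical call to the removed separation script; paths and values are placeholders.
subprocess.run([
    sys.executable, "main/inference/separator_music.py",
    "--input_path", "audios/song.wav",
    "--output_path", "./audios",
    "--format", "wav",
    "--model_name", "HT-Normal",
    "--segments_size", "256",
    "--overlap", "0.25",
    "--backing", "True",
    "--reverb", "True",
    "--backing_reverb", "True",
], check=True)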
main/inference/train.py
DELETED
@@ -1,990 +0,0 @@
-import os
-import sys
-import glob
-import json
-import torch
-import hashlib
-import logging
-import argparse
-import datetime
-import warnings
-import logging.handlers
-
-import numpy as np
-import soundfile as sf
-import matplotlib.pyplot as plt
-import torch.distributed as dist
-import torch.utils.data as tdata
-import torch.multiprocessing as mp
-
-from tqdm import tqdm
-from collections import OrderedDict
-from random import randint, shuffle
-from torch.utils.checkpoint import checkpoint
-from torch.cuda.amp import GradScaler, autocast
-from torch.utils.tensorboard import SummaryWriter
-
-from time import time as ttime
-from torch.nn import functional as F
-from distutils.util import strtobool
-from librosa.filters import mel as librosa_mel_fn
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.nn.utils.parametrizations import spectral_norm, weight_norm
-
-sys.path.append(os.getcwd())
-
-from main.configs.config import Config
-from main.library.algorithm.residuals import LRELU_SLOPE
-from main.library.algorithm.synthesizers import Synthesizer
-from main.library.algorithm.commons import get_padding, slice_segments, clip_grad_value
-
-MATPLOTLIB_FLAG = False
-main_config = Config()
-translations = main_config.translations
-warnings.filterwarnings("ignore")
-logging.getLogger("torch").setLevel(logging.ERROR)
-
-class HParams:
-    def __init__(self, **kwargs):
-        for k, v in kwargs.items():
-            self[k] = HParams(**v) if isinstance(v, dict) else v
-
-    def keys(self):
-        return self.__dict__.keys()
-
-    def items(self):
-        return self.__dict__.items()
-
-    def values(self):
-        return self.__dict__.values()
-
-    def __len__(self):
-        return len(self.__dict__)
-
-    def __getitem__(self, key):
-        return self.__dict__[key]
-
-    def __setitem__(self, key, value):
-        self.__dict__[key] = value
-
-    def __contains__(self, key):
-        return key in self.__dict__
-
-    def __repr__(self):
-        return repr(self.__dict__)
-
-def parse_arguments():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model_name", type=str, required=True)
-    parser.add_argument("--rvc_version", type=str, default="v2")
-    parser.add_argument("--save_every_epoch", type=int, required=True)
-    parser.add_argument("--save_only_latest", type=lambda x: bool(strtobool(x)), default=True)
-    parser.add_argument("--save_every_weights", type=lambda x: bool(strtobool(x)), default=True)
-    parser.add_argument("--total_epoch", type=int, default=300)
-    parser.add_argument("--sample_rate", type=int, required=True)
-    parser.add_argument("--batch_size", type=int, default=8)
-    parser.add_argument("--gpu", type=str, default="0")
-    parser.add_argument("--pitch_guidance", type=lambda x: bool(strtobool(x)), default=True)
-    parser.add_argument("--g_pretrained_path", type=str, default="")
-    parser.add_argument("--d_pretrained_path", type=str, default="")
-    parser.add_argument("--overtraining_detector", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--overtraining_threshold", type=int, default=50)
-    parser.add_argument("--cleanup", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--cache_data_in_gpu", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--model_author", type=str)
-    parser.add_argument("--vocoder", type=str, default="Default")
-    parser.add_argument("--checkpointing", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--deterministic", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--benchmark", type=lambda x: bool(strtobool(x)), default=False)
-
-    return parser.parse_args()
-
-args = parse_arguments()
-model_name, save_every_epoch, total_epoch, pretrainG, pretrainD, version, gpus, batch_size, sample_rate, pitch_guidance, save_only_latest, save_every_weights, cache_data_in_gpu, overtraining_detector, overtraining_threshold, cleanup, model_author, vocoder, checkpointing = args.model_name, args.save_every_epoch, args.total_epoch, args.g_pretrained_path, args.d_pretrained_path, args.rvc_version, args.gpu, args.batch_size, args.sample_rate, args.pitch_guidance, args.save_only_latest, args.save_every_weights, args.cache_data_in_gpu, args.overtraining_detector, args.overtraining_threshold, args.cleanup, args.model_author, args.vocoder, args.checkpointing
-
-experiment_dir = os.path.join("assets", "logs", model_name)
-training_file_path = os.path.join(experiment_dir, "training_data.json")
-config_save_path = os.path.join(experiment_dir, "config.json")
-torch.backends.cudnn.deterministic = args.deterministic
-torch.backends.cudnn.benchmark = args.benchmark
-lowest_value = {"step": 0, "value": float("inf"), "epoch": 0}
-global_step, last_loss_gen_all, overtrain_save_epoch = 0, 0, 0
-loss_gen_history, smoothed_loss_gen_history, loss_disc_history, smoothed_loss_disc_history = [], [], [], []
-
-with open(config_save_path, "r") as f:
-    config = json.load(f)
-
-config = HParams(**config)
-config.data.training_files = os.path.join(experiment_dir, "filelist.txt")
-logger = logging.getLogger(__name__)
-
-if logger.hasHandlers(): logger.handlers.clear()
-else:
-    console_handler = logging.StreamHandler()
-    console_handler.setFormatter(logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
-    console_handler.setLevel(logging.INFO)
-    file_handler = logging.handlers.RotatingFileHandler(os.path.join(experiment_dir, "train.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
-    file_handler.setFormatter(logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
-    file_handler.setLevel(logging.DEBUG)
-    logger.addHandler(console_handler)
-    logger.addHandler(file_handler)
-    logger.setLevel(logging.DEBUG)
-
-log_data = {translations['modelname']: model_name, translations["save_every_epoch"]: save_every_epoch, translations["total_e"]: total_epoch, translations["dorg"].format(pretrainG=pretrainG, pretrainD=pretrainD): "", translations['training_version']: version, "Gpu": gpus, translations['batch_size']: batch_size, translations['pretrain_sr']: sample_rate, translations['training_f0']: pitch_guidance, translations['save_only_latest']: save_only_latest, translations['save_every_weights']: save_every_weights, translations['cache_in_gpu']: cache_data_in_gpu, translations['overtraining_detector']: overtraining_detector, translations['threshold']: overtraining_threshold, translations['cleanup_training']: cleanup, translations['memory_efficient_training']: checkpointing}
-if model_author: log_data[translations["model_author"].format(model_author=model_author)] = ""
-if vocoder != "Default": log_data[translations['vocoder']] = vocoder
-
-for key, value in log_data.items():
-    logger.debug(f"{key}: {value}" if value != "" else f"{key} {value}")
-
-def main():
-    global training_file_path, last_loss_gen_all, smoothed_loss_gen_history, loss_gen_history, loss_disc_history, smoothed_loss_disc_history, overtrain_save_epoch, model_author, vocoder, checkpointing, gpus
-
-    try:
-        os.environ["MASTER_ADDR"] = "localhost"
-        os.environ["MASTER_PORT"] = str(randint(20000, 55555))
-
-        if torch.cuda.is_available():
-            device, gpus = torch.device("cuda"), [int(item) for item in gpus.split("-")]
-            n_gpus = len(gpus)
-        elif torch.backends.mps.is_available():
-            device, gpus = torch.device("mps"), [0]
-            n_gpus = 1
-        else:
-            device, gpus = torch.device("cpu"), [0]
-            n_gpus = 1
-            logger.warning(translations["not_gpu"])
-
-        def start():
-            children = []
-            pid_data = {"process_pids": []}
-
-            with open(config_save_path, "r") as pid_file:
-                try:
-                    pid_data.update(json.load(pid_file))
-                except json.JSONDecodeError:
-                    pass
-
-            with open(config_save_path, "w") as pid_file:
-                for rank, device_id in enumerate(gpus):
-                    subproc = mp.Process(target=run, args=(rank, n_gpus, experiment_dir, pretrainG, pretrainD, pitch_guidance, total_epoch, save_every_weights, config, device, device_id, model_author, vocoder, checkpointing))
-                    children.append(subproc)
-                    subproc.start()
-                    pid_data["process_pids"].append(subproc.pid)
-
-                json.dump(pid_data, pid_file, indent=4)
-
-            for i in range(n_gpus):
-                children[i].join()
-
-        def load_from_json(file_path):
-            if os.path.exists(file_path):
-                with open(file_path, "r") as f:
-                    data = json.load(f)
-                    return (data.get("loss_disc_history", []), data.get("smoothed_loss_disc_history", []), data.get("loss_gen_history", []), data.get("smoothed_loss_gen_history", []))
-            return [], [], [], []
-
-        def continue_overtrain_detector(training_file_path):
-            if overtraining_detector and os.path.exists(training_file_path): (loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history) = load_from_json(training_file_path)
-
-        if cleanup:
-            for root, dirs, files in os.walk(experiment_dir, topdown=False):
-                for name in files:
-                    file_path = os.path.join(root, name)
-                    _, file_extension = os.path.splitext(name)
-                    if (file_extension == ".0" or (name.startswith("D_") and file_extension == ".pth") or (name.startswith("G_") and file_extension == ".pth") or (file_extension == ".index")): os.remove(file_path)
-
-                for name in dirs:
-                    if name == "eval":
-                        folder_path = os.path.join(root, name)
-                        for item in os.listdir(folder_path):
-                            item_path = os.path.join(folder_path, item)
-                            if os.path.isfile(item_path): os.remove(item_path)
-                        os.rmdir(folder_path)
-
-        continue_overtrain_detector(training_file_path)
-        start()
-    except Exception as e:
-        logger.error(f"{translations['training_error']} {e}")
-        import traceback
-        logger.debug(traceback.format_exc())
-
-def plot_spectrogram_to_numpy(spectrogram):
-    global MATPLOTLIB_FLAG
-
-    if not MATPLOTLIB_FLAG:
-        plt.switch_backend("Agg")
-        MATPLOTLIB_FLAG = True
-
-    fig, ax = plt.subplots(figsize=(10, 2))
-    plt.colorbar(ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none"), ax=ax)
-    plt.xlabel("Frames")
-    plt.ylabel("Channels")
-    plt.tight_layout()
-    fig.canvas.draw()
-    plt.close(fig)
-
-    try:
-        data = np.array(fig.canvas.renderer.buffer_rgba(), dtype=np.uint8).reshape(fig.canvas.get_width_height()[::-1] + (4,))[:, :, :3]
-    except:
-        data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="").reshape(fig.canvas.get_width_height()[::-1] + (3,))
-
-    return data
-
-def verify_checkpoint_shapes(checkpoint_path, model):
-    checkpoint = torch.load(checkpoint_path, map_location="cpu")
-    checkpoint_state_dict = checkpoint["model"]
-    try:
-        model_state_dict = model.module.load_state_dict(checkpoint_state_dict) if hasattr(model, "module") else model.load_state_dict(checkpoint_state_dict)
-    except RuntimeError:
-        logger.warning(translations["checkpointing_err"])
-        sys.exit(1)
-    else: del checkpoint, checkpoint_state_dict, model_state_dict
-
-def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sample_rate=22050):
-    for k, v in scalars.items():
-        writer.add_scalar(k, v, global_step)
-
-    for k, v in histograms.items():
-        writer.add_histogram(k, v, global_step)
-
-    for k, v in images.items():
-        writer.add_image(k, v, global_step, dataformats="HWC")
-
-    for k, v in audios.items():
-        writer.add_audio(k, v, global_step, audio_sample_rate)
-
-def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
-    assert os.path.isfile(checkpoint_path), translations["not_found_checkpoint"].format(checkpoint_path=checkpoint_path)
-    checkpoint_dict = replace_keys_in_dict(replace_keys_in_dict(torch.load(checkpoint_path, map_location="cpu"), ".weight_v", ".parametrizations.weight.original1"), ".weight_g", ".parametrizations.weight.original0")
-    new_state_dict = {k: checkpoint_dict["model"].get(k, v) for k, v in (model.module.state_dict() if hasattr(model, "module") else model.state_dict()).items()}
-
-    if hasattr(model, "module"): model.module.load_state_dict(new_state_dict, strict=False)
-    else: model.load_state_dict(new_state_dict, strict=False)
-
-    if optimizer and load_opt == 1: optimizer.load_state_dict(checkpoint_dict.get("optimizer", {}))
-    logger.debug(translations["save_checkpoint"].format(checkpoint_path=checkpoint_path, checkpoint_dict=checkpoint_dict['iteration']))
-    return (model, optimizer, checkpoint_dict.get("learning_rate", 0), checkpoint_dict["iteration"])
-
-def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
-    state_dict = (model.module.state_dict() if hasattr(model, "module") else model.state_dict())
-    torch.save(replace_keys_in_dict(replace_keys_in_dict({"model": state_dict, "iteration": iteration, "optimizer": optimizer.state_dict(), "learning_rate": learning_rate}, ".parametrizations.weight.original1", ".weight_v"), ".parametrizations.weight.original0", ".weight_g"), checkpoint_path)
-    logger.info(translations["save_model"].format(checkpoint_path=checkpoint_path, iteration=iteration))
-
-def latest_checkpoint_path(dir_path, regex="G_*.pth"):
-    checkpoints = sorted(glob.glob(os.path.join(dir_path, regex)), key=lambda f: int("".join(filter(str.isdigit, f))))
-    return checkpoints[-1] if checkpoints else None
-
-def load_wav_to_torch(full_path):
-    data, sample_rate = sf.read(full_path, dtype=np.float32)
-    return torch.FloatTensor(data.astype(np.float32)), sample_rate
-
-def load_filepaths_and_text(filename, split="|"):
-    with open(filename, encoding="utf-8") as f:
-        return [line.strip().split(split) for line in f]
-
-def feature_loss(fmap_r, fmap_g):
-    loss = 0
-    for dr, dg in zip(fmap_r, fmap_g):
-        for rl, gl in zip(dr, dg):
-            loss += torch.mean(torch.abs(rl.float().detach() - gl.float()))
-    return loss * 2
-
-def discriminator_loss(disc_real_outputs, disc_generated_outputs):
-    loss = 0
-    r_losses, g_losses = [], []
-
-    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-        dr = dr.float()
-        dg = dg.float()
-        r_loss = torch.mean((1 - dr) ** 2)
-        g_loss = torch.mean(dg**2)
-        loss += r_loss + g_loss
-        r_losses.append(r_loss.item())
-        g_losses.append(g_loss.item())
-    return loss, r_losses, g_losses
-
-def generator_loss(disc_outputs):
-    loss = 0
-    gen_losses = []
-
-    for dg in disc_outputs:
-        l = torch.mean((1 - dg.float()) ** 2)
-        gen_losses.append(l)
-        loss += l
-    return loss, gen_losses
-
-def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
-    z_p = z_p.float()
-    logs_q = logs_q.float()
-    m_p = m_p.float()
-    logs_p = logs_p.float()
-    z_mask = z_mask.float()
-    kl = logs_p - logs_q - 0.5
-    kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
-    return torch.sum(kl * z_mask) / torch.sum(z_mask)
-
-class TextAudioLoaderMultiNSFsid(tdata.Dataset):
-    def __init__(self, hparams):
-        self.audiopaths_and_text = load_filepaths_and_text(hparams.training_files)
-        self.max_wav_value = hparams.max_wav_value
-        self.sample_rate = hparams.sample_rate
-        self.filter_length = hparams.filter_length
-        self.hop_length = hparams.hop_length
-        self.win_length = hparams.win_length
-        self.sample_rate = hparams.sample_rate
-        self.min_text_len = getattr(hparams, "min_text_len", 1)
-        self.max_text_len = getattr(hparams, "max_text_len", 5000)
-        self._filter()
-
-    def _filter(self):
-        audiopaths_and_text_new, lengths = [], []
-        for audiopath, text, pitch, pitchf, dv in self.audiopaths_and_text:
-            if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
-                audiopaths_and_text_new.append([audiopath, text, pitch, pitchf, dv])
-                lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
-
-        self.audiopaths_and_text = audiopaths_and_text_new
-        self.lengths = lengths
-
-    def get_sid(self, sid):
-        try:
-            sid = torch.LongTensor([int(sid)])
-        except ValueError as e:
-            logger.error(translations["sid_error"].format(sid=sid, e=e))
-            sid = torch.LongTensor([0])
-        return sid
-
-    def get_audio_text_pair(self, audiopath_and_text):
-        phone, pitch, pitchf = self.get_labels(audiopath_and_text[1], audiopath_and_text[2], audiopath_and_text[3])
-        spec, wav = self.get_audio(audiopath_and_text[0])
-        dv = self.get_sid(audiopath_and_text[4])
-        len_phone = phone.size()[0]
-        len_spec = spec.size()[-1]
-
-        if len_phone != len_spec:
-            len_min = min(len_phone, len_spec)
-            len_wav = len_min * self.hop_length
-            spec, wav, phone = spec[:, :len_min], wav[:, :len_wav], phone[:len_min, :]
-            pitch, pitchf = pitch[:len_min], pitchf[:len_min]
-        return (spec, wav, phone, pitch, pitchf, dv)
-
-    def get_labels(self, phone, pitch, pitchf):
-        phone = np.repeat(np.load(phone), 2, axis=0)
-        n_num = min(phone.shape[0], 900)
-        return torch.FloatTensor(phone[:n_num, :]), torch.LongTensor(np.load(pitch)[:n_num]), torch.FloatTensor(np.load(pitchf)[:n_num])
-
-    def get_audio(self, filename):
-        audio, sample_rate = load_wav_to_torch(filename)
-        if sample_rate != self.sample_rate: raise ValueError(translations["sr_does_not_match"].format(sample_rate=sample_rate, sample_rate2=self.sample_rate))
-        audio_norm = audio.unsqueeze(0)
-        spec_filename = filename.replace(".wav", ".spec.pt")
-
-        if os.path.exists(spec_filename):
-            try:
-                spec = torch.load(spec_filename)
-            except Exception as e:
-                logger.error(translations["spec_error"].format(spec_filename=spec_filename, e=e))
-                spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
-                torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
-        else:
-            spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
-            torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
-        return spec, audio_norm
-
-    def __getitem__(self, index):
-        return self.get_audio_text_pair(self.audiopaths_and_text[index])
-
-    def __len__(self):
-        return len(self.audiopaths_and_text)
-
-class TextAudioCollateMultiNSFsid:
-    def __init__(self, return_ids=False):
-        self.return_ids = return_ids
-
-    def __call__(self, batch):
-        _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True)
-        spec_lengths, wave_lengths = torch.LongTensor(len(batch)), torch.LongTensor(len(batch))
-        spec_padded, wave_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max([x[0].size(1) for x in batch])), torch.FloatTensor(len(batch), 1, max([x[1].size(1) for x in batch]))
-        spec_padded.zero_()
-        wave_padded.zero_()
-        max_phone_len = max([x[2].size(0) for x in batch])
-        phone_lengths, phone_padded = torch.LongTensor(len(batch)), torch.FloatTensor(len(batch), max_phone_len, batch[0][2].shape[1])
-        pitch_padded, pitchf_padded = torch.LongTensor(len(batch), max_phone_len), torch.FloatTensor(len(batch), max_phone_len)
-        phone_padded.zero_()
-        pitch_padded.zero_()
-        pitchf_padded.zero_()
-        sid = torch.LongTensor(len(batch))
-
-        for i in range(len(ids_sorted_decreasing)):
-            row = batch[ids_sorted_decreasing[i]]
-            spec = row[0]
-            spec_padded[i, :, : spec.size(1)] = spec
-            spec_lengths[i] = spec.size(1)
-            wave = row[1]
-            wave_padded[i, :, : wave.size(1)] = wave
-            wave_lengths[i] = wave.size(1)
-            phone = row[2]
-            phone_padded[i, : phone.size(0), :] = phone
-            phone_lengths[i] = phone.size(0)
-            pitch = row[3]
-            pitch_padded[i, : pitch.size(0)] = pitch
|
432 |
-
pitchf = row[4]
|
433 |
-
pitchf_padded[i, : pitchf.size(0)] = pitchf
|
434 |
-
sid[i] = row[5]
|
435 |
-
return (phone_padded, phone_lengths, pitch_padded, pitchf_padded, spec_padded, spec_lengths, wave_padded, wave_lengths, sid)
|
436 |
-
|
437 |
-
class TextAudioLoader(tdata.Dataset):
|
438 |
-
def __init__(self, hparams):
|
439 |
-
self.audiopaths_and_text = load_filepaths_and_text(hparams.training_files)
|
440 |
-
self.max_wav_value = hparams.max_wav_value
|
441 |
-
self.sample_rate = hparams.sample_rate
|
442 |
-
self.filter_length = hparams.filter_length
|
443 |
-
self.hop_length = hparams.hop_length
|
444 |
-
self.win_length = hparams.win_length
|
445 |
-
self.sample_rate = hparams.sample_rate
|
446 |
-
self.min_text_len = getattr(hparams, "min_text_len", 1)
|
447 |
-
self.max_text_len = getattr(hparams, "max_text_len", 5000)
|
448 |
-
self._filter()
|
449 |
-
|
450 |
-
def _filter(self):
|
451 |
-
audiopaths_and_text_new, lengths = [], []
|
452 |
-
for entry in self.audiopaths_and_text:
|
453 |
-
if len(entry) >= 3:
|
454 |
-
audiopath, text, dv = entry[:3]
|
455 |
-
if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
|
456 |
-
audiopaths_and_text_new.append([audiopath, text, dv])
|
457 |
-
lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
|
458 |
-
|
459 |
-
self.audiopaths_and_text = audiopaths_and_text_new
|
460 |
-
self.lengths = lengths
|
461 |
-
|
462 |
-
def get_sid(self, sid):
|
463 |
-
try:
|
464 |
-
sid = torch.LongTensor([int(sid)])
|
465 |
-
except ValueError as e:
|
466 |
-
logger.error(translations["sid_error"].format(sid=sid, e=e))
|
467 |
-
sid = torch.LongTensor([0])
|
468 |
-
return sid
|
469 |
-
|
470 |
-
def get_audio_text_pair(self, audiopath_and_text):
|
471 |
-
phone = self.get_labels(audiopath_and_text[1])
|
472 |
-
spec, wav = self.get_audio(audiopath_and_text[0])
|
473 |
-
dv = self.get_sid(audiopath_and_text[2])
|
474 |
-
len_phone = phone.size()[0]
|
475 |
-
len_spec = spec.size()[-1]
|
476 |
-
|
477 |
-
if len_phone != len_spec:
|
478 |
-
len_min = min(len_phone, len_spec)
|
479 |
-
len_wav = len_min * self.hop_length
|
480 |
-
spec = spec[:, :len_min]
|
481 |
-
wav = wav[:, :len_wav]
|
482 |
-
phone = phone[:len_min, :]
|
483 |
-
return (spec, wav, phone, dv)
|
484 |
-
|
485 |
-
def get_labels(self, phone):
|
486 |
-
phone = np.repeat(np.load(phone), 2, axis=0)
|
487 |
-
return torch.FloatTensor(phone[:min(phone.shape[0], 900), :])
|
488 |
-
|
489 |
-
def get_audio(self, filename):
|
490 |
-
audio, sample_rate = load_wav_to_torch(filename)
|
491 |
-
if sample_rate != self.sample_rate: raise ValueError(translations["sr_does_not_match"].format(sample_rate=sample_rate, sample_rate2=self.sample_rate))
|
492 |
-
audio_norm = audio.unsqueeze(0)
|
493 |
-
spec_filename = filename.replace(".wav", ".spec.pt")
|
494 |
-
|
495 |
-
if os.path.exists(spec_filename):
|
496 |
-
try:
|
497 |
-
spec = torch.load(spec_filename)
|
498 |
-
except Exception as e:
|
499 |
-
logger.error(translations["spec_error"].format(spec_filename=spec_filename, e=e))
|
500 |
-
spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
|
501 |
-
torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
|
502 |
-
else:
|
503 |
-
spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
|
504 |
-
torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
|
505 |
-
return spec, audio_norm
|
506 |
-
|
507 |
-
def __getitem__(self, index):
|
508 |
-
return self.get_audio_text_pair(self.audiopaths_and_text[index])
|
509 |
-
|
510 |
-
def __len__(self):
|
511 |
-
return len(self.audiopaths_and_text)
|
512 |
-
|
513 |
-
class TextAudioCollate:
|
514 |
-
def __init__(self, return_ids=False):
|
515 |
-
self.return_ids = return_ids
|
516 |
-
|
517 |
-
def __call__(self, batch):
|
518 |
-
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True)
|
519 |
-
spec_lengths, wave_lengths = torch.LongTensor(len(batch)), torch.LongTensor(len(batch))
|
520 |
-
spec_padded, wave_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max([x[0].size(1) for x in batch])), torch.FloatTensor(len(batch), 1, max([x[1].size(1) for x in batch]))
|
521 |
-
spec_padded.zero_()
|
522 |
-
wave_padded.zero_()
|
523 |
-
max_phone_len = max([x[2].size(0) for x in batch])
|
524 |
-
phone_lengths, phone_padded = torch.LongTensor(len(batch)), torch.FloatTensor(len(batch), max_phone_len, batch[0][2].shape[1])
|
525 |
-
phone_padded.zero_()
|
526 |
-
sid = torch.LongTensor(len(batch))
|
527 |
-
for i in range(len(ids_sorted_decreasing)):
|
528 |
-
row = batch[ids_sorted_decreasing[i]]
|
529 |
-
spec = row[0]
|
530 |
-
spec_padded[i, :, : spec.size(1)] = spec
|
531 |
-
spec_lengths[i] = spec.size(1)
|
532 |
-
wave = row[1]
|
533 |
-
wave_padded[i, :, : wave.size(1)] = wave
|
534 |
-
wave_lengths[i] = wave.size(1)
|
535 |
-
phone = row[2]
|
536 |
-
phone_padded[i, : phone.size(0), :] = phone
|
537 |
-
phone_lengths[i] = phone.size(0)
|
538 |
-
sid[i] = row[3]
|
539 |
-
return (phone_padded, phone_lengths, spec_padded, spec_lengths, wave_padded, wave_lengths, sid)
|
540 |
-
|
541 |
-
class DistributedBucketSampler(tdata.distributed.DistributedSampler):
|
542 |
-
def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
|
543 |
-
super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
|
544 |
-
self.lengths = dataset.lengths
|
545 |
-
self.batch_size = batch_size
|
546 |
-
self.boundaries = boundaries
|
547 |
-
self.buckets, self.num_samples_per_bucket = self._create_buckets()
|
548 |
-
self.total_size = sum(self.num_samples_per_bucket)
|
549 |
-
self.num_samples = self.total_size // self.num_replicas
|
550 |
-
|
551 |
-
def _create_buckets(self):
|
552 |
-
buckets = [[] for _ in range(len(self.boundaries) - 1)]
|
553 |
-
for i in range(len(self.lengths)):
|
554 |
-
idx_bucket = self._bisect(self.lengths[i])
|
555 |
-
if idx_bucket != -1: buckets[idx_bucket].append(i)
|
556 |
-
|
557 |
-
for i in range(len(buckets) - 1, -1, -1):
|
558 |
-
if len(buckets[i]) == 0:
|
559 |
-
buckets.pop(i)
|
560 |
-
self.boundaries.pop(i + 1)
|
561 |
-
|
562 |
-
num_samples_per_bucket = []
|
563 |
-
for i in range(len(buckets)):
|
564 |
-
len_bucket = len(buckets[i])
|
565 |
-
total_batch_size = self.num_replicas * self.batch_size
|
566 |
-
num_samples_per_bucket.append(len_bucket + ((total_batch_size - (len_bucket % total_batch_size)) % total_batch_size))
|
567 |
-
return buckets, num_samples_per_bucket
|
568 |
-
|
569 |
-
def __iter__(self):
|
570 |
-
g = torch.Generator()
|
571 |
-
g.manual_seed(self.epoch)
|
572 |
-
indices, batches = [], []
|
573 |
-
if self.shuffle:
|
574 |
-
for bucket in self.buckets:
|
575 |
-
indices.append(torch.randperm(len(bucket), generator=g).tolist())
|
576 |
-
else:
|
577 |
-
for bucket in self.buckets:
|
578 |
-
indices.append(list(range(len(bucket))))
|
579 |
-
|
580 |
-
for i in range(len(self.buckets)):
|
581 |
-
bucket = self.buckets[i]
|
582 |
-
len_bucket = len(bucket)
|
583 |
-
ids_bucket = indices[i]
|
584 |
-
rem = self.num_samples_per_bucket[i] - len_bucket
|
585 |
-
ids_bucket = (ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: (rem % len_bucket)])[self.rank :: self.num_replicas]
|
586 |
-
|
587 |
-
for j in range(len(ids_bucket) // self.batch_size):
|
588 |
-
batches.append([bucket[idx] for idx in ids_bucket[j * self.batch_size : (j + 1) * self.batch_size]])
|
589 |
-
|
590 |
-
if self.shuffle: batches = [batches[i] for i in torch.randperm(len(batches), generator=g).tolist()]
|
591 |
-
self.batches = batches
|
592 |
-
assert len(self.batches) * self.batch_size == self.num_samples
|
593 |
-
return iter(self.batches)
|
594 |
-
|
595 |
-
def _bisect(self, x, lo=0, hi=None):
|
596 |
-
if hi is None: hi = len(self.boundaries) - 1
|
597 |
-
|
598 |
-
if hi > lo:
|
599 |
-
mid = (hi + lo) // 2
|
600 |
-
if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: return mid
|
601 |
-
elif x <= self.boundaries[mid]: return self._bisect(x, lo, mid)
|
602 |
-
else: return self._bisect(x, mid + 1, hi)
|
603 |
-
else: return -1
|
604 |
-
|
605 |
-
def __len__(self):
|
606 |
-
return self.num_samples // self.batch_size
|
607 |
-
|
608 |
-
class MultiPeriodDiscriminator(torch.nn.Module):
|
609 |
-
def __init__(self, version, use_spectral_norm=False, checkpointing=False):
|
610 |
-
super(MultiPeriodDiscriminator, self).__init__()
|
611 |
-
self.checkpointing = checkpointing
|
612 |
-
periods = ([2, 3, 5, 7, 11, 17] if version == "v1" else [2, 3, 5, 7, 11, 17, 23, 37])
|
613 |
-
self.discriminators = torch.nn.ModuleList([DiscriminatorS(use_spectral_norm=use_spectral_norm, checkpointing=checkpointing)] + [DiscriminatorP(p, use_spectral_norm=use_spectral_norm, checkpointing=checkpointing) for p in periods])
|
614 |
-
|
615 |
-
def forward(self, y, y_hat):
|
616 |
-
y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
|
617 |
-
for d in self.discriminators:
|
618 |
-
if self.training and self.checkpointing:
|
619 |
-
def forward_discriminator(d, y, y_hat):
|
620 |
-
y_d_r, fmap_r = d(y)
|
621 |
-
y_d_g, fmap_g = d(y_hat)
|
622 |
-
return y_d_r, fmap_r, y_d_g, fmap_g
|
623 |
-
y_d_r, fmap_r, y_d_g, fmap_g = checkpoint(forward_discriminator, d, y, y_hat, use_reentrant=False)
|
624 |
-
else:
|
625 |
-
y_d_r, fmap_r = d(y)
|
626 |
-
y_d_g, fmap_g = d(y_hat)
|
627 |
-
|
628 |
-
y_d_rs.append(y_d_r); fmap_rs.append(fmap_r)
|
629 |
-
y_d_gs.append(y_d_g); fmap_gs.append(fmap_g)
|
630 |
-
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
631 |
-
|
632 |
-
class DiscriminatorS(torch.nn.Module):
|
633 |
-
def __init__(self, use_spectral_norm=False, checkpointing=False):
|
634 |
-
super(DiscriminatorS, self).__init__()
|
635 |
-
self.checkpointing = checkpointing
|
636 |
-
norm_f = spectral_norm if use_spectral_norm else weight_norm
|
637 |
-
self.convs = torch.nn.ModuleList([norm_f(torch.nn.Conv1d(1, 16, 15, 1, padding=7)), norm_f(torch.nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)), norm_f(torch.nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)), norm_f(torch.nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)), norm_f(torch.nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), norm_f(torch.nn.Conv1d(1024, 1024, 5, 1, padding=2))])
|
638 |
-
self.conv_post = norm_f(torch.nn.Conv1d(1024, 1, 3, 1, padding=1))
|
639 |
-
self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE)
|
640 |
-
|
641 |
-
def forward(self, x):
|
642 |
-
fmap = []
|
643 |
-
for conv in self.convs:
|
644 |
-
x = checkpoint(self.lrelu, checkpoint(conv, x, use_reentrant = False), use_reentrant = False) if self.training and self.checkpointing else self.lrelu(conv(x))
|
645 |
-
fmap.append(x)
|
646 |
-
|
647 |
-
x = self.conv_post(x)
|
648 |
-
fmap.append(x)
|
649 |
-
return torch.flatten(x, 1, -1), fmap
|
650 |
-
|
651 |
-
class DiscriminatorP(torch.nn.Module):
|
652 |
-
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, checkpointing=False):
|
653 |
-
super(DiscriminatorP, self).__init__()
|
654 |
-
self.period = period
|
655 |
-
self.checkpointing = checkpointing
|
656 |
-
norm_f = spectral_norm if use_spectral_norm else weight_norm
|
657 |
-
self.convs = torch.nn.ModuleList([norm_f(torch.nn.Conv2d(in_ch, out_ch, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))) for in_ch, out_ch in zip([1, 32, 128, 512, 1024], [32, 128, 512, 1024, 1024])])
|
658 |
-
self.conv_post = norm_f(torch.nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
659 |
-
self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE)
|
660 |
-
|
661 |
-
def forward(self, x):
|
662 |
-
fmap = []
|
663 |
-
b, c, t = x.shape
|
664 |
-
|
665 |
-
if t % self.period != 0: x = F.pad(x, (0, (self.period - (t % self.period))), "reflect")
|
666 |
-
x = x.view(b, c, -1, self.period)
|
667 |
-
|
668 |
-
for conv in self.convs:
|
669 |
-
x = checkpoint(self.lrelu, checkpoint(conv, x, use_reentrant = False), use_reentrant = False) if self.training and self.checkpointing else self.lrelu(conv(x))
|
670 |
-
fmap.append(x)
|
671 |
-
|
672 |
-
x = self.conv_post(x)
|
673 |
-
fmap.append(x)
|
674 |
-
return torch.flatten(x, 1, -1), fmap
|
675 |
-
|
676 |
-
class EpochRecorder:
|
677 |
-
def __init__(self):
|
678 |
-
self.last_time = ttime()
|
679 |
-
|
680 |
-
def record(self):
|
681 |
-
now_time = ttime()
|
682 |
-
elapsed_time = now_time - self.last_time
|
683 |
-
self.last_time = now_time
|
684 |
-
return translations["time_or_speed_training"].format(current_time=datetime.datetime.now().strftime("%H:%M:%S"), elapsed_time_str=str(datetime.timedelta(seconds=int(round(elapsed_time, 1)))))
|
685 |
-
|
686 |
-
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
687 |
-
return torch.log(torch.clamp(x, min=clip_val) * C)
|
688 |
-
|
689 |
-
def dynamic_range_decompression_torch(x, C=1):
|
690 |
-
return torch.exp(x) / C
|
691 |
-
|
692 |
-
def spectral_normalize_torch(magnitudes):
|
693 |
-
return dynamic_range_compression_torch(magnitudes)
|
694 |
-
|
695 |
-
def spectral_de_normalize_torch(magnitudes):
|
696 |
-
return dynamic_range_decompression_torch(magnitudes)
|
697 |
-
|
698 |
-
mel_basis, hann_window = {}, {}
|
699 |
-
|
700 |
-
def spectrogram_torch(y, n_fft, hop_size, win_size, center=False):
|
701 |
-
global hann_window
|
702 |
-
|
703 |
-
wnsize_dtype_device = str(win_size) + "_" + str(y.dtype) + "_" + str(y.device)
|
704 |
-
if wnsize_dtype_device not in hann_window: hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
|
705 |
-
spec = torch.stft(F.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect").squeeze(1), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
|
706 |
-
return torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6)
|
707 |
-
|
708 |
-
def spec_to_mel_torch(spec, n_fft, num_mels, sample_rate, fmin, fmax):
|
709 |
-
global mel_basis
|
710 |
-
|
711 |
-
fmax_dtype_device = str(fmax) + "_" + str(spec.dtype) + "_" + str(spec.device)
|
712 |
-
if fmax_dtype_device not in mel_basis: mel_basis[fmax_dtype_device] = torch.from_numpy(librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)).to(dtype=spec.dtype, device=spec.device)
|
713 |
-
return spectral_normalize_torch(torch.matmul(mel_basis[fmax_dtype_device], spec))
|
714 |
-
|
715 |
-
def mel_spectrogram_torch(y, n_fft, num_mels, sample_rate, hop_size, win_size, fmin, fmax, center=False):
|
716 |
-
return spec_to_mel_torch(spectrogram_torch(y, n_fft, hop_size, win_size, center), n_fft, num_mels, sample_rate, fmin, fmax)
|
717 |
-
|
718 |
-
def replace_keys_in_dict(d, old_key_part, new_key_part):
|
719 |
-
updated_dict = OrderedDict() if isinstance(d, OrderedDict) else {}
|
720 |
-
for key, value in d.items():
|
721 |
-
updated_dict[(key.replace(old_key_part, new_key_part) if isinstance(key, str) else key)] = (replace_keys_in_dict(value, old_key_part, new_key_part) if isinstance(value, dict) else value)
|
722 |
-
return updated_dict
|
723 |
-
|
724 |
-
def extract_model(ckpt, sr, pitch_guidance, name, model_path, epoch, step, version, hps, model_author, vocoder):
|
725 |
-
try:
|
726 |
-
logger.info(translations["savemodel"].format(model_dir=model_path, epoch=epoch, step=step))
|
727 |
-
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
728 |
-
|
729 |
-
opt = OrderedDict(weight={key: value.half() for key, value in ckpt.items() if "enc_q" not in key})
|
730 |
-
opt["config"] = [hps.data.filter_length // 2 + 1, 32, hps.model.inter_channels, hps.model.hidden_channels, hps.model.filter_channels, hps.model.n_heads, hps.model.n_layers, hps.model.kernel_size, hps.model.p_dropout, hps.model.resblock, hps.model.resblock_kernel_sizes, hps.model.resblock_dilation_sizes, hps.model.upsample_rates, hps.model.upsample_initial_channel, hps.model.upsample_kernel_sizes, hps.model.spk_embed_dim, hps.model.gin_channels, hps.data.sample_rate]
|
731 |
-
opt["epoch"] = f"{epoch}epoch"
|
732 |
-
opt["step"] = step
|
733 |
-
opt["sr"] = sr
|
734 |
-
opt["f0"] = int(pitch_guidance)
|
735 |
-
opt["version"] = version
|
736 |
-
opt["creation_date"] = datetime.datetime.now().isoformat()
|
737 |
-
opt["model_hash"] = hashlib.sha256(f"{str(ckpt)} {epoch} {step} {datetime.datetime.now().isoformat()}".encode()).hexdigest()
|
738 |
-
opt["model_name"] = name
|
739 |
-
opt["author"] = model_author
|
740 |
-
opt["vocoder"] = vocoder
|
741 |
-
|
742 |
-
torch.save(replace_keys_in_dict(replace_keys_in_dict(opt, ".parametrizations.weight.original1", ".weight_v"), ".parametrizations.weight.original0", ".weight_g"), model_path)
|
743 |
-
except Exception as e:
|
744 |
-
logger.error(f"{translations['extract_model_error']}: {e}")
|
745 |
-
|
746 |
-
def run(rank, n_gpus, experiment_dir, pretrainG, pretrainD, pitch_guidance, custom_total_epoch, custom_save_every_weights, config, device, device_id, model_author, vocoder, checkpointing):
|
747 |
-
global global_step
|
748 |
-
|
749 |
-
if rank == 0: writer_eval = SummaryWriter(log_dir=os.path.join(experiment_dir, "eval"))
|
750 |
-
else: writer_eval = None
|
751 |
-
|
752 |
-
try:
|
753 |
-
dist.init_process_group(backend=("gloo" if sys.platform == "win32" or device.type != "cuda" else "nccl"), init_method="env://", world_size=n_gpus, rank=rank)
|
754 |
-
except:
|
755 |
-
dist.init_process_group(backend=("gloo" if sys.platform == "win32" or device.type != "cuda" else "nccl"), init_method="env://?use_libuv=False", world_size=n_gpus, rank=rank)
|
756 |
-
|
757 |
-
torch.manual_seed(config.train.seed)
|
758 |
-
if torch.cuda.is_available(): torch.cuda.set_device(device_id)
|
759 |
-
|
760 |
-
train_dataset = TextAudioLoaderMultiNSFsid(config.data) if pitch_guidance else TextAudioLoader(config.data)
|
761 |
-
train_loader = tdata.DataLoader(train_dataset, num_workers=4, shuffle=False, pin_memory=True, collate_fn=TextAudioCollateMultiNSFsid() if pitch_guidance else TextAudioCollate(), batch_sampler=DistributedBucketSampler(train_dataset, batch_size * n_gpus, [100, 200, 300, 400, 500, 600, 700, 800, 900], num_replicas=n_gpus, rank=rank, shuffle=True), persistent_workers=True, prefetch_factor=8)
|
762 |
-
|
763 |
-
net_g, net_d = Synthesizer(config.data.filter_length // 2 + 1, config.train.segment_size // config.data.hop_length, **config.model, use_f0=pitch_guidance, sr=sample_rate, vocoder=vocoder, checkpointing=checkpointing), MultiPeriodDiscriminator(version, config.model.use_spectral_norm, checkpointing=checkpointing)
|
764 |
-
net_g, net_d = (net_g.cuda(device_id), net_d.cuda(device_id)) if torch.cuda.is_available() else (net_g.to(device), net_d.to(device))
|
765 |
-
|
766 |
-
optim_g, optim_d = torch.optim.AdamW(net_g.parameters(), config.train.learning_rate, betas=config.train.betas, eps=config.train.eps), torch.optim.AdamW(net_d.parameters(), config.train.learning_rate, betas=config.train.betas, eps=config.train.eps)
|
767 |
-
net_g, net_d = (DDP(net_g, device_ids=[device_id]), DDP(net_d, device_ids=[device_id])) if torch.cuda.is_available() else (DDP(net_g), DDP(net_d))
|
768 |
-
|
769 |
-
try:
|
770 |
-
logger.info(translations["start_training"])
|
771 |
-
_, _, _, epoch_str = load_checkpoint((os.path.join(experiment_dir, "D_latest.pth") if save_only_latest else latest_checkpoint_path(experiment_dir, "D_*.pth")), net_d, optim_d)
|
772 |
-
_, _, _, epoch_str = load_checkpoint((os.path.join(experiment_dir, "G_latest.pth") if save_only_latest else latest_checkpoint_path(experiment_dir, "G_*.pth")), net_g, optim_g)
|
773 |
-
epoch_str += 1
|
774 |
-
global_step = (epoch_str - 1) * len(train_loader)
|
775 |
-
except:
|
776 |
-
epoch_str, global_step = 1, 0
|
777 |
-
|
778 |
-
if pretrainG != "" and pretrainG != "None":
|
779 |
-
if rank == 0:
|
780 |
-
verify_checkpoint_shapes(pretrainG, net_g)
|
781 |
-
logger.info(translations["import_pretrain"].format(dg="G", pretrain=pretrainG))
|
782 |
-
|
783 |
-
if hasattr(net_g, "module"): net_g.module.load_state_dict(torch.load(pretrainG, map_location="cpu")["model"])
|
784 |
-
else: net_g.load_state_dict(torch.load(pretrainG, map_location="cpu")["model"])
|
785 |
-
else: logger.warning(translations["not_using_pretrain"].format(dg="G"))
|
786 |
-
|
787 |
-
if pretrainD != "" and pretrainD != "None":
|
788 |
-
if rank == 0:
|
789 |
-
verify_checkpoint_shapes(pretrainD, net_d)
|
790 |
-
logger.info(translations["import_pretrain"].format(dg="D", pretrain=pretrainD))
|
791 |
-
|
792 |
-
if hasattr(net_d, "module"): net_d.module.load_state_dict(torch.load(pretrainD, map_location="cpu")["model"])
|
793 |
-
else: net_d.load_state_dict(torch.load(pretrainD, map_location="cpu")["model"])
|
794 |
-
else: logger.warning(translations["not_using_pretrain"].format(dg="D"))
|
795 |
-
|
796 |
-
scheduler_g, scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.train.lr_decay, last_epoch=epoch_str - 2), torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.train.lr_decay, last_epoch=epoch_str - 2)
|
797 |
-
optim_d.step(); optim_g.step()
|
798 |
-
|
799 |
-
scaler = GradScaler(enabled=main_config.is_half and device.type == "cuda")
|
800 |
-
cache = []
|
801 |
-
|
802 |
-
for info in train_loader:
|
803 |
-
phone, phone_lengths, pitch, pitchf, _, _, _, _, sid = info
|
804 |
-
reference = (phone.cuda(device_id, non_blocking=True), phone_lengths.cuda(device_id, non_blocking=True), (pitch.cuda(device_id, non_blocking=True) if pitch_guidance else None), (pitchf.cuda(device_id, non_blocking=True) if pitch_guidance else None), sid.cuda(device_id, non_blocking=True)) if device.type == "cuda" else (phone.to(device), phone_lengths.to(device), (pitch.to(device) if pitch_guidance else None), (pitchf.to(device) if pitch_guidance else None), sid.to(device))
|
805 |
-
break
|
806 |
-
|
807 |
-
for epoch in range(epoch_str, total_epoch + 1):
|
808 |
-
train_and_evaluate(rank, epoch, config, [net_g, net_d], [optim_g, optim_d], scaler, train_loader, writer_eval, cache, custom_save_every_weights, custom_total_epoch, device, device_id, reference, model_author, vocoder)
|
809 |
-
scheduler_g.step(); scheduler_d.step()
|
810 |
-
|
811 |
-
def train_and_evaluate(rank, epoch, hps, nets, optims, scaler, train_loader, writer, cache, custom_save_every_weights, custom_total_epoch, device, device_id, reference, model_author, vocoder):
|
812 |
-
global global_step, lowest_value, loss_disc, consecutive_increases_gen, consecutive_increases_disc
|
813 |
-
|
814 |
-
if epoch == 1:
|
815 |
-
lowest_value = {"step": 0, "value": float("inf"), "epoch": 0}
|
816 |
-
last_loss_gen_all, consecutive_increases_gen, consecutive_increases_disc = 0.0, 0, 0
|
817 |
-
|
818 |
-
net_g, net_d = nets
|
819 |
-
optim_g, optim_d = optims
|
820 |
-
train_loader.batch_sampler.set_epoch(epoch)
|
821 |
-
|
822 |
-
net_g.train(); net_d.train()
|
823 |
-
|
824 |
-
if device.type == "cuda" and cache_data_in_gpu:
|
825 |
-
data_iterator = cache
|
826 |
-
if cache == []:
|
827 |
-
for batch_idx, info in enumerate(train_loader):
|
828 |
-
cache.append((batch_idx, [tensor.cuda(device_id, non_blocking=True) for tensor in info]))
|
829 |
-
else: shuffle(cache)
|
830 |
-
else: data_iterator = enumerate(train_loader)
|
831 |
-
|
832 |
-
epoch_recorder = EpochRecorder()
|
833 |
-
with tqdm(total=len(train_loader), leave=False) as pbar:
|
834 |
-
for batch_idx, info in data_iterator:
|
835 |
-
if device.type == "cuda" and not cache_data_in_gpu: info = [tensor.cuda(device_id, non_blocking=True) for tensor in info]
|
836 |
-
elif device.type != "cuda": info = [tensor.to(device) for tensor in info]
|
837 |
-
|
838 |
-
phone, phone_lengths, pitch, pitchf, spec, spec_lengths, wave, _, sid = info
|
839 |
-
pitch = pitch if pitch_guidance else None
|
840 |
-
pitchf = pitchf if pitch_guidance else None
|
841 |
-
|
842 |
-
with autocast(enabled=main_config.is_half and device.type == "cuda"):
|
843 |
-
y_hat, ids_slice, _, z_mask, (_, z_p, m_p, logs_p, _, logs_q) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
|
844 |
-
mel = spec_to_mel_torch(spec, config.data.filter_length, config.data.n_mel_channels, config.data.sample_rate, config.data.mel_fmin, config.data.mel_fmax)
|
845 |
-
y_mel = slice_segments(mel, ids_slice, config.train.segment_size // config.data.hop_length, dim=3)
|
846 |
-
|
847 |
-
with autocast(enabled=main_config.is_half and device.type == "cuda"):
|
848 |
-
y_hat_mel = mel_spectrogram_torch(y_hat.float().squeeze(1), config.data.filter_length, config.data.n_mel_channels, config.data.sample_rate, config.data.hop_length, config.data.win_length, config.data.mel_fmin, config.data.mel_fmax)
|
849 |
-
|
850 |
-
wave = slice_segments(wave, ids_slice * config.data.hop_length, config.train.segment_size, dim=3)
|
851 |
-
y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
|
852 |
-
|
853 |
-
with autocast(enabled=main_config.is_half and device.type == "cuda"):
|
854 |
-
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
|
855 |
-
|
856 |
-
optim_d.zero_grad()
|
857 |
-
scaler.scale(loss_disc).backward()
|
858 |
-
scaler.unscale_(optim_d)
|
859 |
-
grad_norm_d = clip_grad_value(net_d.parameters(), None)
|
860 |
-
scaler.step(optim_d)
|
861 |
-
|
862 |
-
with autocast(enabled=main_config.is_half and device.type == "cuda"):
|
863 |
-
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
|
864 |
-
with autocast(enabled=main_config.is_half and device.type == "cuda"):
|
865 |
-
loss_mel = F.l1_loss(y_mel, y_hat_mel) * config.train.c_mel
|
866 |
-
loss_kl = (kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl)
|
867 |
-
loss_fm = feature_loss(fmap_r, fmap_g)
|
868 |
-
loss_gen, losses_gen = generator_loss(y_d_hat_g)
|
869 |
-
loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
|
870 |
-
if loss_gen_all < lowest_value["value"]:
|
871 |
-
lowest_value["value"] = loss_gen_all
|
872 |
-
lowest_value["step"] = global_step
|
873 |
-
lowest_value["epoch"] = epoch
|
874 |
-
if epoch > lowest_value["epoch"]: logger.warning(translations["training_warning"])
|
875 |
-
|
876 |
-
optim_g.zero_grad()
|
877 |
-
scaler.scale(loss_gen_all).backward()
|
878 |
-
scaler.unscale_(optim_g)
|
879 |
-
grad_norm_g = clip_grad_value(net_g.parameters(), None)
|
880 |
-
scaler.step(optim_g)
|
881 |
-
scaler.update()
|
882 |
-
|
883 |
-
if rank == 0 and global_step % config.train.log_interval == 0:
|
884 |
-
if loss_mel > 75: loss_mel = 75
|
885 |
-
if loss_kl > 9: loss_kl = 9
|
886 |
-
|
887 |
-
scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc, "learning_rate": optim_g.param_groups[0]["lr"], "grad/norm_d": grad_norm_d, "grad/norm_g": grad_norm_g, "loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl}
|
888 |
-
scalar_dict.update({f"loss/g/{i}": v for i, v in enumerate(losses_gen)})
|
889 |
-
scalar_dict.update({f"loss/d_r/{i}": v for i, v in enumerate(losses_disc_r)})
|
890 |
-
scalar_dict.update({f"loss/d_g/{i}": v for i, v in enumerate(losses_disc_g)})
|
891 |
-
|
892 |
-
with torch.no_grad():
|
893 |
-
o, *_ = net_g.module.infer(*reference) if hasattr(net_g, "module") else net_g.infer(*reference)
|
894 |
-
|
895 |
-
summarize(writer=writer, global_step=global_step, images={"slice/mel_org": plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), "slice/mel_gen": plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), "all/mel": plot_spectrogram_to_numpy(mel[0].data.cpu().numpy())}, scalars=scalar_dict, audios={f"gen/audio_{global_step:07d}": o[0, :, :]}, audio_sample_rate=config.data.sample_rate)
|
896 |
-
|
897 |
-
global_step += 1
|
898 |
-
pbar.update(1)
|
899 |
-
|
900 |
-
def check_overtraining(smoothed_loss_history, threshold, epsilon=0.004):
|
901 |
-
if len(smoothed_loss_history) < threshold + 1: return False
|
902 |
-
for i in range(-threshold, -1):
|
903 |
-
if smoothed_loss_history[i + 1] > smoothed_loss_history[i]: return True
|
904 |
-
if abs(smoothed_loss_history[i + 1] - smoothed_loss_history[i]) >= epsilon: return False
|
905 |
-
return True
|
906 |
-
|
907 |
-
def update_exponential_moving_average(smoothed_loss_history, new_value, smoothing=0.987):
|
908 |
-
smoothed_value = new_value if not smoothed_loss_history else (smoothing * smoothed_loss_history[-1] + (1 - smoothing) * new_value)
|
909 |
-
smoothed_loss_history.append(smoothed_value)
|
910 |
-
return smoothed_value
|
911 |
-
|
912 |
-
def save_to_json(file_path, loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history):
|
913 |
-
with open(file_path, "w") as f:
|
914 |
-
json.dump({"loss_disc_history": loss_disc_history, "smoothed_loss_disc_history": smoothed_loss_disc_history, "loss_gen_history": loss_gen_history, "smoothed_loss_gen_history": smoothed_loss_gen_history}, f)
|
915 |
-
|
916 |
-
model_add, model_del = [], []
|
917 |
-
done = False
|
918 |
-
|
919 |
-
if rank == 0:
|
920 |
-
if epoch % save_every_epoch == False:
|
921 |
-
checkpoint_suffix = f"{'latest' if save_only_latest else global_step}.pth"
|
922 |
-
save_checkpoint(net_g, optim_g, config.train.learning_rate, epoch, os.path.join(experiment_dir, "G_" + checkpoint_suffix))
|
923 |
-
save_checkpoint(net_d, optim_d, config.train.learning_rate, epoch, os.path.join(experiment_dir, "D_" + checkpoint_suffix))
|
924 |
-
if custom_save_every_weights: model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s.pth"))
|
925 |
-
|
926 |
-
if overtraining_detector and epoch > 1:
|
927 |
-
current_loss_disc = float(loss_disc)
|
928 |
-
loss_disc_history.append(current_loss_disc)
|
929 |
-
smoothed_value_disc = update_exponential_moving_average(smoothed_loss_disc_history, current_loss_disc)
|
930 |
-
is_overtraining_disc = check_overtraining(smoothed_loss_disc_history, overtraining_threshold * 2)
|
931 |
-
|
932 |
-
if is_overtraining_disc: consecutive_increases_disc += 1
|
933 |
-
else: consecutive_increases_disc = 0
|
934 |
-
|
935 |
-
current_loss_gen = float(lowest_value["value"])
|
936 |
-
loss_gen_history.append(current_loss_gen)
|
937 |
-
smoothed_value_gen = update_exponential_moving_average(smoothed_loss_gen_history, current_loss_gen)
|
938 |
-
is_overtraining_gen = check_overtraining(smoothed_loss_gen_history, overtraining_threshold, 0.01)
|
939 |
-
|
940 |
-
if is_overtraining_gen: consecutive_increases_gen += 1
|
941 |
-
else: consecutive_increases_gen = 0
|
942 |
-
|
943 |
-
if epoch % save_every_epoch == 0: save_to_json(training_file_path, loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history)
|
944 |
-
|
945 |
-
if (is_overtraining_gen and consecutive_increases_gen == overtraining_threshold or is_overtraining_disc and consecutive_increases_disc == (overtraining_threshold * 2)):
|
946 |
-
logger.info(translations["overtraining_find"].format(epoch=epoch, smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
|
947 |
-
done = True
|
948 |
-
else:
|
949 |
-
logger.info(translations["best_epoch"].format(epoch=epoch, smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
|
950 |
-
for file in glob.glob(os.path.join("assets", "weights", f"{model_name}_*e_*s_best_epoch.pth")):
|
951 |
-
model_del.append(file)
|
952 |
-
|
953 |
-
model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s_best_epoch.pth"))
|
954 |
-
|
955 |
-
if epoch >= custom_total_epoch:
|
956 |
-
logger.info(translations["success_training"].format(epoch=epoch, global_step=global_step, loss_gen_all=round(loss_gen_all.item(), 3)))
|
957 |
-
logger.info(translations["training_info"].format(lowest_value_rounded=round(float(lowest_value["value"]), 3), lowest_value_epoch=lowest_value['epoch'], lowest_value_step=lowest_value['step']))
|
958 |
-
|
959 |
-
pid_file_path = os.path.join(experiment_dir, "config.json")
|
960 |
-
with open(pid_file_path, "r") as pid_file:
|
961 |
-
pid_data = json.load(pid_file)
|
962 |
-
|
963 |
-
with open(pid_file_path, "w") as pid_file:
|
964 |
-
pid_data.pop("process_pids", None)
|
965 |
-
json.dump(pid_data, pid_file, indent=4)
|
966 |
-
|
967 |
-
if os.path.exists(os.path.join(experiment_dir, "train_pid.txt")): os.remove(os.path.join(experiment_dir, "train_pid.txt"))
|
968 |
-
model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s.pth"))
|
969 |
-
done = True
|
970 |
-
|
971 |
-
for m in model_del:
|
972 |
-
os.remove(m)
|
973 |
-
|
974 |
-
if model_add:
|
975 |
-
ckpt = (net_g.module.state_dict() if hasattr(net_g, "module") else net_g.state_dict())
|
976 |
-
for m in model_add:
|
977 |
-
extract_model(ckpt=ckpt, sr=sample_rate, pitch_guidance=pitch_guidance == True, name=model_name, model_path=m, epoch=epoch, step=global_step, version=version, hps=hps, model_author=model_author, vocoder=vocoder)
|
978 |
-
|
979 |
-
lowest_value_rounded = round(float(lowest_value["value"]), 3)
|
980 |
-
|
981 |
-
if epoch > 1 and overtraining_detector: logger.info(translations["model_training_info"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record(), lowest_value_rounded=lowest_value_rounded, lowest_value_epoch=lowest_value['epoch'], lowest_value_step=lowest_value['step'], remaining_epochs_gen=(overtraining_threshold - consecutive_increases_gen), remaining_epochs_disc=((overtraining_threshold * 2) - consecutive_increases_disc), smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
|
982 |
-
elif epoch > 1 and overtraining_detector == False: logger.info(translations["model_training_info_2"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record(), lowest_value_rounded=lowest_value_rounded, lowest_value_epoch=lowest_value['epoch'], lowest_value_step=lowest_value['step']))
|
983 |
-
else: logger.info(translations["model_training_info_3"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record()))
|
984 |
-
|
985 |
-
last_loss_gen_all = loss_gen_all
|
986 |
-
if done: os._exit(0)
|
987 |
-
|
988 |
-
if __name__ == "__main__":
|
989 |
-
torch.multiprocessing.set_start_method("spawn")
|
990 |
-
main()
main/library/algorithm/commons.py
DELETED
@@ -1,60 +0,0 @@
-import torch
-
-
-
-def init_weights(m, mean=0.0, std=0.01):
-    if m.__class__.__name__.find("Conv") != -1: m.weight.data.normal_(mean, std)
-
-def get_padding(kernel_size, dilation=1):
-    return int((kernel_size * dilation - dilation) / 2)
-
-def convert_pad_shape(pad_shape):
-    return [item for sublist in pad_shape[::-1] for item in sublist]
-
-def slice_segments(x, ids_str, segment_size = 4, dim = 2):
-    if dim == 2: ret = torch.zeros_like(x[:, :segment_size])
-    elif dim == 3: ret = torch.zeros_like(x[:, :, :segment_size])
-
-    for i in range(x.size(0)):
-        idx_str = ids_str[i].item()
-        idx_end = idx_str + segment_size
-
-        if dim == 2: ret[i] = x[i, idx_str:idx_end]
-        else: ret[i] = x[i, :, idx_str:idx_end]
-
-    return ret
-
-def rand_slice_segments(x, x_lengths=None, segment_size=4):
-    b, _, t = x.size()
-    if x_lengths is None: x_lengths = t
-
-    ids_str = (torch.rand([b]).to(device=x.device) * (x_lengths - segment_size + 1)).to(dtype=torch.long)
-
-    return slice_segments(x, ids_str, segment_size, dim=3), ids_str
-
-@torch.jit.script
-def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
-    n_channels_int = n_channels[0]
-
-    in_act = input_a + input_b
-
-    return torch.tanh(in_act[:, :n_channels_int, :]) * torch.sigmoid(in_act[:, n_channels_int:, :])
-
-def sequence_mask(length, max_length = None):
-    if max_length is None: max_length = length.max()
-
-    return torch.arange(max_length, dtype=length.dtype, device=length.device).unsqueeze(0) < length.unsqueeze(1)
-
-def clip_grad_value(parameters, clip_value, norm_type=2):
-    if isinstance(parameters, torch.Tensor): parameters = [parameters]
-    norm_type = float(norm_type)
-
-    if clip_value is not None: clip_value = float(clip_value)
-    total_norm = 0
-
-    for p in list(filter(lambda p: p.grad is not None, parameters)):
-        total_norm += (p.grad.data.norm(norm_type)).item() ** norm_type
-
-        if clip_value is not None: p.grad.data.clamp_(min=-clip_value, max=clip_value)
-
-    return total_norm ** (1.0 / norm_type)
main/library/algorithm/modules.py
DELETED
@@ -1,60 +0,0 @@
-import os
-import sys
-import torch
-
-sys.path.append(os.getcwd())
-
-from .commons import fused_add_tanh_sigmoid_multiply
-
-class WaveNet(torch.nn.Module):
-    def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
-        super(WaveNet, self).__init__()
-        assert kernel_size % 2 == 1
-        self.hidden_channels = hidden_channels
-        self.kernel_size = (kernel_size,)
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.gin_channels = gin_channels
-        self.p_dropout = p_dropout
-        self.in_layers = torch.nn.ModuleList()
-        self.res_skip_layers = torch.nn.ModuleList()
-        self.drop = torch.nn.Dropout(p_dropout)
-        if gin_channels != 0: self.cond_layer = torch.nn.utils.parametrizations.weight_norm(torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1), name="weight")
-        dilations = [dilation_rate ** i for i in range(n_layers)]
-        paddings = [(kernel_size * d - d) // 2 for d in dilations]
-
-        for i in range(n_layers):
-            in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilations[i], padding=paddings[i])
-            in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
-            self.in_layers.append(in_layer)
-            res_skip_channels = (hidden_channels if i == n_layers - 1 else 2 * hidden_channels)
-            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
-            res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
-            self.res_skip_layers.append(res_skip_layer)
-
-    def forward(self, x, x_mask, g=None):
-        output = x.clone().zero_()
-        n_channels_tensor = torch.IntTensor([self.hidden_channels])
-
-        if g is not None: g = self.cond_layer(g)
-
-        for i in range(self.n_layers):
-            x_in = self.in_layers[i](x)
-            g_l = (g[:, i * 2 * self.hidden_channels : (i + 1) * 2 * self.hidden_channels, :] if g is not None else 0)
-            res_skip_acts = self.res_skip_layers[i](self.drop(fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)))
-
-            if i < self.n_layers - 1:
-                x = (x + (res_skip_acts[:, : self.hidden_channels, :])) * x_mask
-                output = output + res_skip_acts[:, self.hidden_channels :, :]
-            else: output = output + res_skip_acts
-
-        return output * x_mask
-
-    def remove_weight_norm(self):
-        if self.gin_channels != 0: torch.nn.utils.remove_weight_norm(self.cond_layer)
-
-        for l in self.in_layers:
-            torch.nn.utils.remove_weight_norm(l)
-
-        for l in self.res_skip_layers:
-            torch.nn.utils.remove_weight_norm(l)
main/library/algorithm/mrf_hifigan.py
DELETED
@@ -1,150 +0,0 @@
-import math
-import torch
-
-import numpy as np
-import torch.nn as nn
-import torch.nn.functional as F
-
-from torch.nn.utils import remove_weight_norm
-from torch.utils.checkpoint import checkpoint
-from torch.nn.utils.parametrizations import weight_norm
-
-LRELU_SLOPE = 0.1
-
-class MRFLayer(nn.Module):
-    def __init__(self, channels, kernel_size, dilation):
-        super().__init__()
-        self.conv1 = weight_norm(nn.Conv1d(channels, channels, kernel_size, padding=(kernel_size * dilation - dilation) // 2, dilation=dilation))
-        self.conv2 = weight_norm(nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2, dilation=1))
-
-    def forward(self, x):
-        return x + self.conv2(F.leaky_relu(self.conv1(F.leaky_relu(x, LRELU_SLOPE)), LRELU_SLOPE))
-
-    def remove_weight_norm(self):
-        remove_weight_norm(self.conv1)
-        remove_weight_norm(self.conv2)
-
-class MRFBlock(nn.Module):
-    def __init__(self, channels, kernel_size, dilations):
-        super().__init__()
-        self.layers = nn.ModuleList()
-
-        for dilation in dilations:
-            self.layers.append(MRFLayer(channels, kernel_size, dilation))
-
-    def forward(self, x):
-        for layer in self.layers:
-            x = layer(x)
-
-        return x
-
-    def remove_weight_norm(self):
-        for layer in self.layers:
-            layer.remove_weight_norm()
-
-class SineGenerator(nn.Module):
-    def __init__(self, samp_rate, harmonic_num = 0, sine_amp = 0.1, noise_std = 0.003, voiced_threshold = 0):
-        super(SineGenerator, self).__init__()
-        self.sine_amp = sine_amp
-        self.noise_std = noise_std
-        self.harmonic_num = harmonic_num
-        self.dim = self.harmonic_num + 1
-        self.sampling_rate = samp_rate
-        self.voiced_threshold = voiced_threshold
-
-    def _f02uv(self, f0):
-        return torch.ones_like(f0) * (f0 > self.voiced_threshold)
-
-    def _f02sine(self, f0_values):
-        rad_values = (f0_values / self.sampling_rate) % 1
-        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], dtype=f0_values.dtype, device=f0_values.device)
-        rand_ini[:, 0] = 0
-        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
-        tmp_over_one = torch.cumsum(rad_values, 1) % 1
-        tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
-        cumsum_shift = torch.zeros_like(rad_values)
-        cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
-
-        return torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)
-
-    def forward(self, f0):
-        with torch.no_grad():
-            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, dtype=f0.dtype, device=f0.device)
-            f0_buf[:, :, 0] = f0[:, :, 0]
-
-            for idx in np.arange(self.harmonic_num):
-                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
-
-            sine_waves = self._f02sine(f0_buf) * self.sine_amp
-            uv = self._f02uv(f0)
-            sine_waves = sine_waves * uv + ((uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves))
-
-        return sine_waves
-
-class SourceModuleHnNSF(nn.Module):
-    def __init__(self, sampling_rate, harmonic_num = 0, sine_amp = 0.1, add_noise_std = 0.003, voiced_threshold = 0):
-        super(SourceModuleHnNSF, self).__init__()
-        self.sine_amp = sine_amp
-        self.noise_std = add_noise_std
-        self.l_sin_gen = SineGenerator(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshold)
-        self.l_linear = nn.Linear(harmonic_num + 1, 1)
-        self.l_tanh = nn.Tanh()
-
-    def forward(self, x):
-        return self.l_tanh(self.l_linear(self.l_sin_gen(x).to(dtype=self.l_linear.weight.dtype)))
-
-class HiFiGANMRFGenerator(nn.Module):
-    def __init__(self, in_channel, upsample_initial_channel, upsample_rates, upsample_kernel_sizes, resblock_kernel_sizes, resblock_dilations, gin_channels, sample_rate, harmonic_num, checkpointing = False):
-        super().__init__()
-        self.num_kernels = len(resblock_kernel_sizes)
-        self.checkpointing = checkpointing
-        self.f0_upsample = nn.Upsample(scale_factor=np.prod(upsample_rates))
-        self.m_source = SourceModuleHnNSF(sample_rate, harmonic_num)
-        self.conv_pre = weight_norm(nn.Conv1d(in_channel, upsample_initial_channel, kernel_size=7, stride=1, padding=3))
-        self.upsamples = nn.ModuleList()
-        self.noise_convs = nn.ModuleList()
-        stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1 for i in range(len(upsample_rates))]
-
-        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
-            self.upsamples.append(weight_norm(nn.ConvTranspose1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), kernel_size=k, stride=u, padding=((k - u) // 2) if u % 2 == 0 else (u // 2 + u % 2), output_padding=u % 2)))
-            stride = stride_f0s[i]
-            kernel = 1 if stride == 1 else stride * 2 - stride % 2
-            self.noise_convs.append(nn.Conv1d(1, upsample_initial_channel // (2 ** (i + 1)), kernel_size=kernel, stride=stride, padding=0 if stride == 1 else (kernel - stride) // 2))
-
-        self.mrfs = nn.ModuleList()
-        for i in range(len(self.upsamples)):
-            channel = upsample_initial_channel // (2 ** (i + 1))
-            self.mrfs.append(nn.ModuleList([MRFBlock(channel, kernel_size=k, dilations=d) for k, d in zip(resblock_kernel_sizes, resblock_dilations)]))
-
-        self.conv_post = weight_norm(nn.Conv1d(channel, 1, kernel_size=7, stride=1, padding=3))
-        if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
-
-    def forward(self, x, f0, g = None):
-        har_source = self.m_source(self.f0_upsample(f0[:, None, :]).transpose(-1, -2)).transpose(-1, -2)
-        x = self.conv_pre(x)
-        if g is not None: x += self.cond(g)
-
-        for ups, mrf, noise_conv in zip(self.upsamples, self.mrfs, self.noise_convs):
-            x = F.leaky_relu(x, LRELU_SLOPE)
-
-            if self.training and self.checkpointing:
-                x = checkpoint(ups, x, use_reentrant=False) + noise_conv(har_source)
-                xs = sum([checkpoint(layer, x, use_reentrant=False) for layer in mrf])
-            else:
-                x = ups(x) + noise_conv(har_source)
-                xs = sum([layer(x) for layer in mrf])
-
-            x = xs / self.num_kernels
-
-        return torch.tanh(self.conv_post(F.leaky_relu(x)))
-
-    def remove_weight_norm(self):
-        remove_weight_norm(self.conv_pre)
-
-        for up in self.upsamples:
-            remove_weight_norm(up)
-
-        for mrf in self.mrfs:
-            mrf.remove_weight_norm()
-
-        remove_weight_norm(self.conv_post)
main/library/algorithm/onnx_export.py
DELETED
@@ -1,50 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import io
|
3 |
-
import sys
|
4 |
-
import onnx
|
5 |
-
import json
|
6 |
-
import torch
|
7 |
-
import onnxsim
|
8 |
-
import warnings
|
9 |
-
|
10 |
-
sys.path.append(os.getcwd())
|
11 |
-
|
12 |
-
from main.library.algorithm.synthesizers import SynthesizerONNX
|
13 |
-
|
14 |
-
warnings.filterwarnings("ignore")
|
15 |
-
|
16 |
-
def onnx_exporter(input_path, output_path, is_half=False, device="cpu"):
|
17 |
-
cpt = (torch.load(input_path, map_location="cpu") if os.path.isfile(input_path) else None)
|
18 |
-
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
|
19 |
-
|
20 |
-
model_name, model_author, epochs, steps, version, f0, model_hash, vocoder, creation_date = cpt.get("model_name", None), cpt.get("author", None), cpt.get("epoch", None), cpt.get("step", None), cpt.get("version", "v1"), cpt.get("f0", 1), cpt.get("model_hash", None), cpt.get("vocoder", "Default"), cpt.get("creation_date", None)
|
21 |
-
text_enc_hidden_dim = 768 if version == "v2" else 256
|
22 |
-
tgt_sr = cpt["config"][-1]
|
23 |
-
|
24 |
-
net_g = SynthesizerONNX(*cpt["config"], use_f0=f0, text_enc_hidden_dim=text_enc_hidden_dim, vocoder=vocoder, checkpointing=False)
|
25 |
-
net_g.load_state_dict(cpt["weight"], strict=False)
|
26 |
-
net_g.eval().to(device)
|
27 |
-
net_g = (net_g.half() if is_half else net_g.float())
|
28 |
-
|
29 |
-
phone = torch.rand(1, 200, text_enc_hidden_dim).to(device)
|
30 |
-
phone_length = torch.tensor([200]).long().to(device)
|
31 |
-
ds = torch.LongTensor([0]).to(device)
|
32 |
-
rnd = torch.rand(1, 192, 200).to(device)
|
33 |
-
|
34 |
-
if f0:
|
35 |
-
args = (phone, phone_length, ds, rnd, torch.randint(size=(1, 200), low=5, high=255).to(device), torch.rand(1, 200).to(device))
|
36 |
-
input_names = ["phone", "phone_lengths", "ds", "rnd", "pitch", "pitchf"]
|
37 |
-
dynamic_axes = {"phone": [1], "rnd": [2], "pitch": [1], "pitchf": [1]}
|
38 |
-
else:
|
39 |
-
args = (phone, phone_length, ds, rnd)
|
40 |
-
input_names = ["phone", "phone_lengths", "ds", "rnd"]
|
41 |
-
dynamic_axes = {"phone": [1], "rnd": [2]}
|
42 |
-
|
43 |
-
with io.BytesIO() as model:
|
44 |
-
torch.onnx.export(net_g, args, model, do_constant_folding=True, opset_version=17, verbose=False, input_names=input_names, output_names=["audio"], dynamic_axes=dynamic_axes)
|
45 |
-
|
46 |
-
model, _ = onnxsim.simplify(onnx.load_model_from_string(model.getvalue()))
|
47 |
-
model.metadata_props.append(onnx.StringStringEntryProto(key="model_info", value=json.dumps({"model_name": model_name, "author": model_author, "epoch": epochs, "step": steps, "version": version, "sr": tgt_sr, "f0": f0, "model_hash": model_hash, "creation_date": creation_date, "vocoder": vocoder, "text_enc_hidden_dim": text_enc_hidden_dim})))
|
48 |
-
|
49 |
-
onnx.save(model, output_path)
|
50 |
-
return output_path
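For reference, the entry point removed here took a trained .pth checkpoint and wrote a simplified ONNX graph with the model metadata embedded in it. A minimal invocation would have looked like the sketch below (the file paths are placeholders):

from main.library.algorithm.onnx_export import onnx_exporter

# Export a voice-model checkpoint to ONNX; is_half=True would export fp16 weights instead.
onnx_exporter("assets/weights/model.pth", "assets/weights/model.onnx",
              is_half=False, device="cpu")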
main/library/algorithm/refinegan.py
DELETED
@@ -1,170 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import sys
|
3 |
-
import math
|
4 |
-
import torch
|
5 |
-
|
6 |
-
import numpy as np
|
7 |
-
import torch.nn as nn
|
8 |
-
import torch.nn.functional as F
|
9 |
-
|
10 |
-
from torch.utils.checkpoint import checkpoint
|
11 |
-
from torch.nn.utils import remove_weight_norm
|
12 |
-
from torch.nn.utils.parametrizations import weight_norm
|
13 |
-
|
14 |
-
sys.path.append(os.getcwd())
|
15 |
-
|
16 |
-
from main.library.algorithm.commons import init_weights, get_padding
|
17 |
-
|
18 |
-
|
19 |
-
class ResBlock(nn.Module):
|
20 |
-
def __init__(self, channels, kernel_size = 7, dilation = (1, 3, 5), leaky_relu_slope = 0.2):
|
21 |
-
super().__init__()
|
22 |
-
self.leaky_relu_slope = leaky_relu_slope
|
23 |
-
self.convs1 = nn.ModuleList([weight_norm(nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=d, padding=get_padding(kernel_size, d))) for d in dilation])
|
24 |
-
self.convs1.apply(init_weights)
|
25 |
-
self.convs2 = nn.ModuleList([weight_norm(nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=1, padding=get_padding(kernel_size, 1))) for _ in dilation])
|
26 |
-
self.convs2.apply(init_weights)
|
27 |
-
|
28 |
-
def forward(self, x):
|
29 |
-
for c1, c2 in zip(self.convs1, self.convs2):
|
30 |
-
x = c2(F.leaky_relu(c1(F.leaky_relu(x, self.leaky_relu_slope)), self.leaky_relu_slope)) + x
|
31 |
-
|
32 |
-
return x
|
33 |
-
|
34 |
-
def remove_weight_norm(self):
|
35 |
-
for c1, c2 in zip(self.convs1, self.convs2):
|
36 |
-
remove_weight_norm(c1)
|
37 |
-
remove_weight_norm(c2)
|
38 |
-
|
39 |
-
class AdaIN(nn.Module):
|
40 |
-
def __init__(self, *, channels, leaky_relu_slope = 0.2):
|
41 |
-
super().__init__()
|
42 |
-
self.weight = nn.Parameter(torch.ones(channels))
|
43 |
-
self.activation = nn.LeakyReLU(leaky_relu_slope)
|
44 |
-
|
45 |
-
def forward(self, x):
|
46 |
-
return self.activation(x + (torch.randn_like(x) * self.weight[None, :, None]))
|
47 |
-
|
48 |
-
class ParallelResBlock(nn.Module):
|
49 |
-
def __init__(self, *, in_channels, out_channels, kernel_sizes = (3, 7, 11), dilation = (1, 3, 5), leaky_relu_slope = 0.2):
|
50 |
-
super().__init__()
|
51 |
-
self.in_channels = in_channels
|
52 |
-
self.out_channels = out_channels
|
53 |
-
self.input_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=7, stride=1, padding=3)
|
54 |
-
self.input_conv.apply(init_weights)
|
55 |
-
self.blocks = nn.ModuleList([nn.Sequential(AdaIN(channels=out_channels), ResBlock(out_channels, kernel_size=kernel_size, dilation=dilation, leaky_relu_slope=leaky_relu_slope), AdaIN(channels=out_channels)) for kernel_size in kernel_sizes])
|
56 |
-
|
57 |
-
def forward(self, x):
|
58 |
-
x = self.input_conv(x)
|
59 |
-
return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)
|
60 |
-
|
61 |
-
def remove_weight_norm(self):
|
62 |
-
remove_weight_norm(self.input_conv)
|
63 |
-
for block in self.blocks:
|
64 |
-
block[1].remove_weight_norm()
|
65 |
-
|
66 |
-
class SineGenerator(nn.Module):
|
67 |
-
def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0):
|
68 |
-
super(SineGenerator, self).__init__()
|
69 |
-
self.sine_amp = sine_amp
|
70 |
-
self.noise_std = noise_std
|
71 |
-
self.harmonic_num = harmonic_num
|
72 |
-
self.dim = self.harmonic_num + 1
|
73 |
-
self.sampling_rate = samp_rate
|
74 |
-
self.voiced_threshold = voiced_threshold
|
75 |
-
self.merge = nn.Sequential(nn.Linear(self.dim, 1, bias=False), nn.Tanh())
|
76 |
-
|
77 |
-
def _f02uv(self, f0):
|
78 |
-
return torch.ones_like(f0) * (f0 > self.voiced_threshold)
|
79 |
-
|
80 |
-
def _f02sine(self, f0_values):
|
81 |
-
rad_values = (f0_values / self.sampling_rate) % 1
|
82 |
-
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], dtype=f0_values.dtype, device=f0_values.device)
|
83 |
-
|
84 |
-
rand_ini[:, 0] = 0
|
85 |
-
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
86 |
-
|
87 |
-
tmp_over_one = torch.cumsum(rad_values, 1) % 1
|
88 |
-
tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
|
89 |
-
|
90 |
-
cumsum_shift = torch.zeros_like(rad_values)
|
91 |
-
cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
92 |
-
|
93 |
-
return torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)
|
94 |
-
|
95 |
-
def forward(self, f0):
|
96 |
-
with torch.no_grad():
|
97 |
-
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, dtype=f0.dtype, device=f0.device)
|
98 |
-
f0_buf[:, :, 0] = f0[:, :, 0]
|
99 |
-
|
100 |
-
for idx in np.arange(self.harmonic_num):
|
101 |
-
f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
|
102 |
-
|
103 |
-
sine_waves = self._f02sine(f0_buf) * self.sine_amp
|
104 |
-
uv = self._f02uv(f0)
|
105 |
-
sine_waves = sine_waves * uv + ((uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves))
|
106 |
-
|
107 |
-
return self.merge(sine_waves)
|
108 |
-
|
109 |
-
class RefineGANGenerator(nn.Module):
|
110 |
-
def __init__(self, *, sample_rate = 44100, upsample_rates = (8, 8, 2, 2), leaky_relu_slope = 0.2, num_mels = 128, gin_channels = 256, checkpointing = False, upsample_initial_channel = 512):
|
111 |
-
super().__init__()
|
112 |
-
self.upsample_rates = upsample_rates
|
113 |
-
self.checkpointing = checkpointing
|
114 |
-
self.leaky_relu_slope = leaky_relu_slope
|
115 |
-
self.upp = np.prod(upsample_rates)
|
116 |
-
self.m_source = SineGenerator(sample_rate)
|
117 |
-
self.pre_conv = weight_norm(nn.Conv1d(1, upsample_initial_channel // 2, 7, 1, padding=3))
|
118 |
-
stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1 for i in range(len(upsample_rates))]
|
119 |
-
|
120 |
-
channels = upsample_initial_channel
|
121 |
-
self.downsample_blocks = nn.ModuleList([])
|
122 |
-
|
123 |
-
for i, _ in enumerate(upsample_rates):
|
124 |
-
stride = stride_f0s[i]
|
125 |
-
kernel = 1 if stride == 1 else stride * 2 - stride % 2
|
126 |
-
|
127 |
-
self.downsample_blocks.append(weight_norm(nn.Conv1d(1, channels // 2 ** (i + 2), kernel, stride, padding=0 if stride == 1 else (kernel - stride) // 2)))
|
128 |
-
|
129 |
-
self.mel_conv = weight_norm(nn.Conv1d(num_mels, channels // 2, 7, 1, padding=3))
|
130 |
-
self.mel_conv.apply(init_weights)
|
131 |
-
|
132 |
-
if gin_channels != 0: self.cond = nn.Conv1d(256, channels // 2, 1)
|
133 |
-
|
134 |
-
self.upsample_blocks = nn.ModuleList([])
|
135 |
-
self.upsample_conv_blocks = nn.ModuleList([])
|
136 |
-
|
137 |
-
for rate in upsample_rates:
|
138 |
-
new_channels = channels // 2
|
139 |
-
self.upsample_blocks.append(nn.Upsample(scale_factor=rate, mode="linear"))
|
140 |
-
self.upsample_conv_blocks.append(ParallelResBlock(in_channels=channels + channels // 4, out_channels=new_channels, kernel_sizes=(3, 7, 11), dilation=(1, 3, 5), leaky_relu_slope=leaky_relu_slope))
|
141 |
-
channels = new_channels
|
142 |
-
|
143 |
-
self.conv_post = weight_norm(nn.Conv1d(channels, 1, 7, 1, padding=3, bias=False))
|
144 |
-
self.conv_post.apply(init_weights)
|
145 |
-
|
146 |
-
def forward(self, mel, f0, g = None):
|
147 |
-
har_source = self.m_source(F.interpolate(f0.unsqueeze(1), size=mel.shape[-1] * self.upp, mode="linear").transpose(1, 2)).transpose(1, 2)
|
148 |
-
x = F.interpolate(self.pre_conv(har_source), size=mel.shape[-1], mode="linear")
|
149 |
-
|
150 |
-
mel = self.mel_conv(mel)
|
151 |
-
if g is not None: mel += self.cond(g)
|
152 |
-
|
153 |
-
x = torch.cat([mel, x], dim=1)
|
154 |
-
|
155 |
-
for ups, res, down in zip(self.upsample_blocks, self.upsample_conv_blocks, self.downsample_blocks):
|
156 |
-
x = F.leaky_relu(x, self.leaky_relu_slope)
|
157 |
-
x = checkpoint(res, torch.cat([checkpoint(ups, x, use_reentrant=False), down(har_source)], dim=1), use_reentrant=False) if self.training and self.checkpointing else res(torch.cat([ups(x), down(har_source)], dim=1))
|
158 |
-
|
159 |
-
return torch.tanh(self.conv_post(F.leaky_relu(x, self.leaky_relu_slope)))
|
160 |
-
|
161 |
-
def remove_weight_norm(self):
|
162 |
-
remove_weight_norm(self.pre_conv)
|
163 |
-
remove_weight_norm(self.mel_conv)
|
164 |
-
remove_weight_norm(self.conv_post)
|
165 |
-
|
166 |
-
for block in self.downsample_blocks:
|
167 |
-
block.remove_weight_norm()
|
168 |
-
|
169 |
-
for block in self.upsample_conv_blocks:
|
170 |
-
block.remove_weight_norm()
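The generator removed above maps a mel spectrogram plus an f0 contour to a waveform, mixing a sine-based harmonic source into every upsampling stage. A rough shape check, with the batch size, frame count and f0 values chosen arbitrarily for illustration:

import torch

# Shapes follow the forward() signature above: mel is (batch, num_mels, frames),
# f0 is (batch, frames) in Hz, g is an optional speaker embedding of gin_channels.
gen = RefineGANGenerator(sample_rate=44100, upsample_rates=(8, 8, 2, 2),
                         num_mels=128, gin_channels=256)
mel = torch.randn(1, 128, 64)
f0 = torch.rand(1, 64) * 300 + 100
g = torch.randn(1, 256, 1)
audio = gen(mel, f0, g)   # waveform of length 64 * 8 * 8 * 2 * 2 samples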
main/library/algorithm/residuals.py
DELETED
@@ -1,140 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import sys
|
3 |
-
import torch
|
4 |
-
|
5 |
-
from torch.nn.utils import remove_weight_norm
|
6 |
-
from torch.nn.utils.parametrizations import weight_norm
|
7 |
-
|
8 |
-
sys.path.append(os.getcwd())
|
9 |
-
|
10 |
-
from .modules import WaveNet
|
11 |
-
from .commons import get_padding, init_weights
|
12 |
-
|
13 |
-
|
14 |
-
LRELU_SLOPE = 0.1
|
15 |
-
|
16 |
-
def create_conv1d_layer(channels, kernel_size, dilation):
|
17 |
-
return weight_norm(torch.nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation, padding=get_padding(kernel_size, dilation)))
|
18 |
-
|
19 |
-
def apply_mask(tensor, mask):
|
20 |
-
return tensor * mask if mask is not None else tensor
|
21 |
-
|
22 |
-
class ResBlockBase(torch.nn.Module):
|
23 |
-
def __init__(self, channels, kernel_size, dilations):
|
24 |
-
super(ResBlockBase, self).__init__()
|
25 |
-
|
26 |
-
self.convs1 = torch.nn.ModuleList([create_conv1d_layer(channels, kernel_size, d) for d in dilations])
|
27 |
-
self.convs1.apply(init_weights)
|
28 |
-
|
29 |
-
self.convs2 = torch.nn.ModuleList([create_conv1d_layer(channels, kernel_size, 1) for _ in dilations])
|
30 |
-
self.convs2.apply(init_weights)
|
31 |
-
|
32 |
-
def forward(self, x, x_mask=None):
|
33 |
-
for c1, c2 in zip(self.convs1, self.convs2):
|
34 |
-
x = c2(apply_mask(torch.nn.functional.leaky_relu(c1(apply_mask(torch.nn.functional.leaky_relu(x, LRELU_SLOPE), x_mask)), LRELU_SLOPE), x_mask)) + x
|
35 |
-
|
36 |
-
return apply_mask(x, x_mask)
|
37 |
-
|
38 |
-
def remove_weight_norm(self):
|
39 |
-
for conv in self.convs1 + self.convs2:
|
40 |
-
remove_weight_norm(conv)
|
41 |
-
|
42 |
-
class ResBlock(ResBlockBase):
|
43 |
-
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
44 |
-
super(ResBlock, self).__init__(channels, kernel_size, dilation)
|
45 |
-
|
46 |
-
class Log(torch.nn.Module):
|
47 |
-
def forward(self, x, x_mask, reverse=False, **kwargs):
|
48 |
-
if not reverse:
|
49 |
-
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
|
50 |
-
return y, torch.sum(-y, [1, 2])
|
51 |
-
else: return torch.exp(x) * x_mask
|
52 |
-
|
53 |
-
class Flip(torch.nn.Module):
|
54 |
-
def forward(self, x, *args, reverse=False, **kwargs):
|
55 |
-
x = torch.flip(x, [1])
|
56 |
-
|
57 |
-
if not reverse: return x, torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
|
58 |
-
else: return x
|
59 |
-
|
60 |
-
class ElementwiseAffine(torch.nn.Module):
|
61 |
-
def __init__(self, channels):
|
62 |
-
super().__init__()
|
63 |
-
self.channels = channels
|
64 |
-
self.m = torch.nn.Parameter(torch.zeros(channels, 1))
|
65 |
-
self.logs = torch.nn.Parameter(torch.zeros(channels, 1))
|
66 |
-
|
67 |
-
def forward(self, x, x_mask, reverse=False, **kwargs):
|
68 |
-
if not reverse: return ((self.m + torch.exp(self.logs) * x) * x_mask), torch.sum(self.logs * x_mask, [1, 2])
|
69 |
-
else: return (x - self.m) * torch.exp(-self.logs) * x_mask
|
70 |
-
|
71 |
-
class ResidualCouplingBlock(torch.nn.Module):
|
72 |
-
def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0):
|
73 |
-
super(ResidualCouplingBlock, self).__init__()
|
74 |
-
self.channels = channels
|
75 |
-
self.hidden_channels = hidden_channels
|
76 |
-
self.kernel_size = kernel_size
|
77 |
-
self.dilation_rate = dilation_rate
|
78 |
-
self.n_layers = n_layers
|
79 |
-
self.n_flows = n_flows
|
80 |
-
self.gin_channels = gin_channels
|
81 |
-
self.flows = torch.nn.ModuleList()
|
82 |
-
|
83 |
-
for _ in range(n_flows):
|
84 |
-
self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
|
85 |
-
self.flows.append(Flip())
|
86 |
-
|
87 |
-
def forward(self, x, x_mask, g = None, reverse = False):
|
88 |
-
if not reverse:
|
89 |
-
for flow in self.flows:
|
90 |
-
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
91 |
-
else:
|
92 |
-
for flow in reversed(self.flows):
|
93 |
-
x = flow.forward(x, x_mask, g=g, reverse=reverse)
|
94 |
-
|
95 |
-
return x
|
96 |
-
|
97 |
-
def remove_weight_norm(self):
|
98 |
-
for i in range(self.n_flows):
|
99 |
-
self.flows[i * 2].remove_weight_norm()
|
100 |
-
|
101 |
-
def __prepare_scriptable__(self):
|
102 |
-
for i in range(self.n_flows):
|
103 |
-
for hook in self.flows[i * 2]._forward_pre_hooks.values():
|
104 |
-
if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(self.flows[i * 2])
|
105 |
-
|
106 |
-
return self
|
107 |
-
|
108 |
-
class ResidualCouplingLayer(torch.nn.Module):
|
109 |
-
def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0, mean_only=False):
|
110 |
-
assert channels % 2 == 0, "Channels/2"
|
111 |
-
super().__init__()
|
112 |
-
self.channels = channels
|
113 |
-
self.hidden_channels = hidden_channels
|
114 |
-
self.kernel_size = kernel_size
|
115 |
-
self.dilation_rate = dilation_rate
|
116 |
-
self.n_layers = n_layers
|
117 |
-
self.half_channels = channels // 2
|
118 |
-
self.mean_only = mean_only
|
119 |
-
|
120 |
-
self.pre = torch.nn.Conv1d(self.half_channels, hidden_channels, 1)
|
121 |
-
self.enc = WaveNet(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
|
122 |
-
self.post = torch.nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
|
123 |
-
|
124 |
-
self.post.weight.data.zero_()
|
125 |
-
self.post.bias.data.zero_()
|
126 |
-
|
127 |
-
def forward(self, x, x_mask, g=None, reverse=False):
|
128 |
-
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
129 |
-
stats = self.post(self.enc((self.pre(x0) * x_mask), x_mask, g=g)) * x_mask
|
130 |
-
|
131 |
-
if not self.mean_only: m, logs = torch.split(stats, [self.half_channels] * 2, 1)
|
132 |
-
else:
|
133 |
-
m = stats
|
134 |
-
logs = torch.zeros_like(m)
|
135 |
-
|
136 |
-
if not reverse: return torch.cat([x0, (m + x1 * torch.exp(logs) * x_mask)], 1), torch.sum(logs, [1, 2])
|
137 |
-
else: return torch.cat([x0, ((x1 - m) * torch.exp(-logs) * x_mask)], 1)
|
138 |
-
|
139 |
-
def remove_weight_norm(self):
|
140 |
-
self.enc.remove_weight_norm()
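The normalizing-flow pieces removed above are used by the synthesizer to map between prior and posterior latents: calling the block runs the coupling layers in order, and reverse=True runs them backwards. A small shape sketch, assuming the WaveNet module from modules.py accepts a (batch, gin_channels, 1) conditioning tensor; the channel sizes and layer counts below are illustrative, not taken from a specific config (192 latent channels matches the random input used by the ONNX exporter above):

import torch

flow = ResidualCouplingBlock(channels=192, hidden_channels=192, kernel_size=5,
                             dilation_rate=1, n_layers=3, gin_channels=256)
z = torch.randn(2, 192, 100)       # (batch, channels, frames)
x_mask = torch.ones(2, 1, 100)     # per-frame validity mask
g = torch.randn(2, 256, 1)         # speaker conditioning
z_p = flow(z, x_mask, g=g)                      # forward direction
z_back = flow(z_p, x_mask, g=g, reverse=True)   # inverse direction approximately recovers z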
main/library/algorithm/separator.py
DELETED
@@ -1,320 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import sys
|
3 |
-
import time
|
4 |
-
import yaml
|
5 |
-
import torch
|
6 |
-
import codecs
|
7 |
-
import hashlib
|
8 |
-
import logging
|
9 |
-
import platform
|
10 |
-
import warnings
|
11 |
-
import requests
|
12 |
-
import onnxruntime
|
13 |
-
|
14 |
-
from importlib import metadata, import_module
|
15 |
-
|
16 |
-
now_dir = os.getcwd()
|
17 |
-
sys.path.append(now_dir)
|
18 |
-
|
19 |
-
from main.configs.config import Config
|
20 |
-
from main.tools.huggingface import HF_download_file
|
21 |
-
|
22 |
-
translations = Config().translations
|
23 |
-
|
24 |
-
|
25 |
-
class Separator:
|
26 |
-
def __init__(self, logger=logging.getLogger(__name__), log_level=logging.INFO, log_formatter=None, model_file_dir="assets/models/uvr5", output_dir=None, output_format="wav", output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=44100, mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False}, demucs_params={"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True}):
|
27 |
-
self.logger = logger
|
28 |
-
self.log_level = log_level
|
29 |
-
self.log_formatter = log_formatter
|
30 |
-
self.log_handler = logging.StreamHandler()
|
31 |
-
|
32 |
-
if self.log_formatter is None: self.log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(module)s - %(message)s")
|
33 |
-
self.log_handler.setFormatter(self.log_formatter)
|
34 |
-
|
35 |
-
if not self.logger.hasHandlers(): self.logger.addHandler(self.log_handler)
|
36 |
-
if log_level > logging.DEBUG: warnings.filterwarnings("ignore")
|
37 |
-
|
38 |
-
self.logger.info(translations["separator_info"].format(output_dir=output_dir, output_format=output_format))
|
39 |
-
self.model_file_dir = model_file_dir
|
40 |
-
|
41 |
-
if output_dir is None:
|
42 |
-
output_dir = now_dir
|
43 |
-
self.logger.info(translations["output_dir_is_none"])
|
44 |
-
|
45 |
-
self.output_dir = output_dir
|
46 |
-
|
47 |
-
os.makedirs(self.model_file_dir, exist_ok=True)
|
48 |
-
os.makedirs(self.output_dir, exist_ok=True)
|
49 |
-
|
50 |
-
self.output_format = output_format
|
51 |
-
self.output_bitrate = output_bitrate
|
52 |
-
|
53 |
-
if self.output_format is None: self.output_format = "wav"
|
54 |
-
self.normalization_threshold = normalization_threshold
|
55 |
-
if normalization_threshold <= 0 or normalization_threshold > 1: raise ValueError(translations[">0or=1"])
|
56 |
-
|
57 |
-
self.output_single_stem = output_single_stem
|
58 |
-
if output_single_stem is not None: self.logger.debug(translations["output_single"].format(output_single_stem=output_single_stem))
|
59 |
-
|
60 |
-
self.invert_using_spec = invert_using_spec
|
61 |
-
if self.invert_using_spec: self.logger.debug(translations["step2"])
|
62 |
-
|
63 |
-
self.sample_rate = int(sample_rate)
|
64 |
-
self.arch_specific_params = {"MDX": mdx_params, "Demucs": demucs_params}
|
65 |
-
self.torch_device = None
|
66 |
-
self.torch_device_cpu = None
|
67 |
-
self.torch_device_mps = None
|
68 |
-
self.onnx_execution_provider = None
|
69 |
-
self.model_instance = None
|
70 |
-
self.model_is_uvr_vip = False
|
71 |
-
self.model_friendly_name = None
|
72 |
-
self.setup_accelerated_inferencing_device()
|
73 |
-
|
74 |
-
def setup_accelerated_inferencing_device(self):
|
75 |
-
system_info = self.get_system_info()
|
76 |
-
self.log_onnxruntime_packages()
|
77 |
-
self.setup_torch_device(system_info)
|
78 |
-
|
79 |
-
def get_system_info(self):
|
80 |
-
os_name = platform.system()
|
81 |
-
os_version = platform.version()
|
82 |
-
self.logger.info(f"{translations['os']}: {os_name} {os_version}")
|
83 |
-
system_info = platform.uname()
|
84 |
-
self.logger.info(translations["platform_info"].format(system_info=system_info, node=system_info.node, release=system_info.release, machine=system_info.machine, processor=system_info.processor))
|
85 |
-
python_version = platform.python_version()
|
86 |
-
self.logger.info(f"{translations['name_ver'].format(name='python')}: {python_version}")
|
87 |
-
pytorch_version = torch.__version__
|
88 |
-
self.logger.info(f"{translations['name_ver'].format(name='pytorch')}: {pytorch_version}")
|
89 |
-
|
90 |
-
return system_info
|
91 |
-
|
92 |
-
def log_onnxruntime_packages(self):
|
93 |
-
onnxruntime_gpu_package = self.get_package_distribution("onnxruntime-gpu")
|
94 |
-
onnxruntime_cpu_package = self.get_package_distribution("onnxruntime")
|
95 |
-
|
96 |
-
if onnxruntime_gpu_package is not None: self.logger.info(f"{translations['install_onnx'].format(pu='GPU')}: {onnxruntime_gpu_package.version}")
|
97 |
-
if onnxruntime_cpu_package is not None: self.logger.info(f"{translations['install_onnx'].format(pu='CPU')}: {onnxruntime_cpu_package.version}")
|
98 |
-
|
99 |
-
def setup_torch_device(self, system_info):
|
100 |
-
hardware_acceleration_enabled = False
|
101 |
-
ort_providers = onnxruntime.get_available_providers()
|
102 |
-
self.torch_device_cpu = torch.device("cpu")
|
103 |
-
|
104 |
-
if torch.cuda.is_available():
|
105 |
-
self.configure_cuda(ort_providers)
|
106 |
-
hardware_acceleration_enabled = True
|
107 |
-
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and system_info.processor == "arm":
|
108 |
-
self.configure_mps(ort_providers)
|
109 |
-
hardware_acceleration_enabled = True
|
110 |
-
|
111 |
-
if not hardware_acceleration_enabled:
|
112 |
-
self.logger.info(translations["running_in_cpu"])
|
113 |
-
self.torch_device = self.torch_device_cpu
|
114 |
-
self.onnx_execution_provider = ["CPUExecutionProvider"]
|
115 |
-
|
116 |
-
def configure_cuda(self, ort_providers):
|
117 |
-
self.logger.info(translations["running_in_cuda"])
|
118 |
-
self.torch_device = torch.device("cuda")
|
119 |
-
|
120 |
-
if "CUDAExecutionProvider" in ort_providers:
|
121 |
-
self.logger.info(translations["onnx_have"].format(have='CUDAExecutionProvider'))
|
122 |
-
self.onnx_execution_provider = ["CUDAExecutionProvider"]
|
123 |
-
else: self.logger.warning(translations["onnx_not_have"].format(have='CUDAExecutionProvider'))
|
124 |
-
|
125 |
-
def configure_mps(self, ort_providers):
|
126 |
-
self.logger.info(translations["set_torch_mps"])
|
127 |
-
self.torch_device_mps = torch.device("mps")
|
128 |
-
self.torch_device = self.torch_device_mps
|
129 |
-
|
130 |
-
if "CoreMLExecutionProvider" in ort_providers:
|
131 |
-
self.logger.info(translations["onnx_have"].format(have='CoreMLExecutionProvider'))
|
132 |
-
self.onnx_execution_provider = ["CoreMLExecutionProvider"]
|
133 |
-
else: self.logger.warning(translations["onnx_not_have"].format(have='CoreMLExecutionProvider'))
|
134 |
-
|
135 |
-
def get_package_distribution(self, package_name):
|
136 |
-
try:
|
137 |
-
return metadata.distribution(package_name)
|
138 |
-
except metadata.PackageNotFoundError:
|
139 |
-
self.logger.debug(translations["python_not_install"].format(package_name=package_name))
|
140 |
-
return None
|
141 |
-
|
142 |
-
def get_model_hash(self, model_path):
|
143 |
-
self.logger.debug(translations["hash"].format(model_path=model_path))
|
144 |
-
|
145 |
-
try:
|
146 |
-
with open(model_path, "rb") as f:
|
147 |
-
f.seek(-10000 * 1024, 2)
|
148 |
-
return hashlib.md5(f.read()).hexdigest()
|
149 |
-
except IOError as e:
|
150 |
-
self.logger.error(translations["ioerror"].format(e=e))
|
151 |
-
return hashlib.md5(open(model_path, "rb").read()).hexdigest()
|
152 |
-
|
153 |
-
def download_file_if_not_exists(self, url, output_path):
|
154 |
-
if os.path.isfile(output_path):
|
155 |
-
self.logger.debug(translations["cancel_download"].format(output_path=output_path))
|
156 |
-
return
|
157 |
-
|
158 |
-
self.logger.debug(translations["download_model"].format(url=url, output_path=output_path))
|
159 |
-
HF_download_file(url, output_path)
|
160 |
-
|
161 |
-
def print_uvr_vip_message(self):
|
162 |
-
if self.model_is_uvr_vip:
|
163 |
-
self.logger.warning(translations["vip_model"].format(model_friendly_name=self.model_friendly_name))
|
164 |
-
self.logger.warning(translations["vip_print"])
|
165 |
-
|
166 |
-
def list_supported_model_files(self):
|
167 |
-
response = requests.get(codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/enj/znva/wfba/hie_zbqryf.wfba", "rot13"))
|
168 |
-
response.raise_for_status()
|
169 |
-
model_downloads_list = response.json()
|
170 |
-
self.logger.debug(translations["load_download_json"])
|
171 |
-
|
172 |
-
return {"MDX": {**model_downloads_list["mdx_download_list"], **model_downloads_list["mdx_download_vip_list"]}, "Demucs": {key: value for key, value in model_downloads_list["demucs_download_list"].items() if key.startswith("Demucs v4")}}
|
173 |
-
|
174 |
-
def download_model_files(self, model_filename):
|
175 |
-
model_path = os.path.join(self.model_file_dir, model_filename)
|
176 |
-
supported_model_files_grouped = self.list_supported_model_files()
|
177 |
-
|
178 |
-
yaml_config_filename = None
|
179 |
-
self.logger.debug(translations["search_model"].format(model_filename=model_filename))
|
180 |
-
|
181 |
-
for model_type, model_list in supported_model_files_grouped.items():
|
182 |
-
for model_friendly_name, model_download_list in model_list.items():
|
183 |
-
self.model_is_uvr_vip = "VIP" in model_friendly_name
|
184 |
-
model_repo_url_prefix = codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/hie5_zbqryf", "rot13")
|
185 |
-
|
186 |
-
if isinstance(model_download_list, str) and model_download_list == model_filename:
|
187 |
-
self.logger.debug(translations["single_model"].format(model_friendly_name=model_friendly_name))
|
188 |
-
self.model_friendly_name = model_friendly_name
|
189 |
-
|
190 |
-
try:
|
191 |
-
self.download_file_if_not_exists(f"{model_repo_url_prefix}/MDX/{model_filename}", model_path)
|
192 |
-
except RuntimeError:
|
193 |
-
self.logger.warning(translations["not_found_model"])
|
194 |
-
self.download_file_if_not_exists(f"{model_repo_url_prefix}/Demucs/{model_filename}", model_path)
|
195 |
-
|
196 |
-
self.print_uvr_vip_message()
|
197 |
-
self.logger.debug(translations["single_model_path"].format(model_path=model_path))
|
198 |
-
|
199 |
-
return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
|
200 |
-
elif isinstance(model_download_list, dict):
|
201 |
-
this_model_matches_input_filename = False
|
202 |
-
|
203 |
-
for file_name, file_url in model_download_list.items():
|
204 |
-
if file_name == model_filename or file_url == model_filename:
|
205 |
-
self.logger.debug(translations["find_model"].format(model_filename=model_filename, model_friendly_name=model_friendly_name))
|
206 |
-
this_model_matches_input_filename = True
|
207 |
-
|
208 |
-
if this_model_matches_input_filename:
|
209 |
-
self.logger.debug(translations["find_models"].format(model_friendly_name=model_friendly_name))
|
210 |
-
self.model_friendly_name = model_friendly_name
|
211 |
-
self.print_uvr_vip_message()
|
212 |
-
|
213 |
-
for config_key, config_value in model_download_list.items():
|
214 |
-
self.logger.debug(f"{translations['find_path']}: {config_key} -> {config_value}")
|
215 |
-
|
216 |
-
if config_value.startswith("http"): self.download_file_if_not_exists(config_value, os.path.join(self.model_file_dir, config_key))
|
217 |
-
elif config_key.endswith(".ckpt"):
|
218 |
-
try:
|
219 |
-
self.download_file_if_not_exists(f"{model_repo_url_prefix}/Demucs/{config_key}", os.path.join(self.model_file_dir, config_key))
|
220 |
-
except RuntimeError:
|
221 |
-
self.logger.warning(translations["not_found_model_warehouse"])
|
222 |
-
|
223 |
-
if model_filename.endswith(".yaml"):
|
224 |
-
self.logger.warning(translations["yaml_warning"].format(model_filename=model_filename))
|
225 |
-
self.logger.warning(translations["yaml_warning_2"].format(config_key=config_key))
|
226 |
-
self.logger.warning(translations["yaml_warning_3"])
|
227 |
-
|
228 |
-
model_filename = config_key
|
229 |
-
model_path = os.path.join(self.model_file_dir, f"{model_filename}")
|
230 |
-
|
231 |
-
yaml_config_filename = config_value
|
232 |
-
yaml_config_filepath = os.path.join(self.model_file_dir, yaml_config_filename)
|
233 |
-
|
234 |
-
try:
|
235 |
-
self.download_file_if_not_exists(f"{model_repo_url_prefix}/mdx_c_configs/{yaml_config_filename}", yaml_config_filepath)
|
236 |
-
except RuntimeError:
|
237 |
-
self.logger.debug(translations["yaml_debug"])
|
238 |
-
else: self.download_file_if_not_exists(f"{model_repo_url_prefix}/Demucs/{config_value}", os.path.join(self.model_file_dir, config_value))
|
239 |
-
|
240 |
-
self.logger.debug(translations["download_model_friendly"].format(model_friendly_name=model_friendly_name, model_path=model_path))
|
241 |
-
return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
|
242 |
-
|
243 |
-
raise ValueError(translations["not_found_model_2"].format(model_filename=model_filename))
|
244 |
-
|
245 |
-
def load_model_data_from_yaml(self, yaml_config_filename):
|
246 |
-
model_data_yaml_filepath = os.path.join(self.model_file_dir, yaml_config_filename) if not os.path.exists(yaml_config_filename) else yaml_config_filename
|
247 |
-
self.logger.debug(translations["load_yaml"].format(model_data_yaml_filepath=model_data_yaml_filepath))
|
248 |
-
|
249 |
-
model_data = yaml.load(open(model_data_yaml_filepath, encoding="utf-8"), Loader=yaml.FullLoader)
|
250 |
-
self.logger.debug(translations["load_yaml_2"].format(model_data=model_data))
|
251 |
-
|
252 |
-
if "roformer" in model_data_yaml_filepath: model_data["is_roformer"] = True
|
253 |
-
return model_data
|
254 |
-
|
255 |
-
def load_model_data_using_hash(self, model_path):
|
256 |
-
self.logger.debug(translations["hash_md5"])
|
257 |
-
model_hash = self.get_model_hash(model_path)
|
258 |
-
|
259 |
-
self.logger.debug(translations["model_hash"].format(model_path=model_path, model_hash=model_hash))
|
260 |
-
mdx_model_data_path = codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/enj/znva/wfba/zbqry_qngn.wfba", "rot13")
|
261 |
-
self.logger.debug(translations["mdx_data"].format(mdx_model_data_path=mdx_model_data_path))
|
262 |
-
|
263 |
-
response = requests.get(mdx_model_data_path)
|
264 |
-
response.raise_for_status()
|
265 |
-
|
266 |
-
mdx_model_data_object = response.json()
|
267 |
-
self.logger.debug(translations["load_mdx"])
|
268 |
-
|
269 |
-
if model_hash in mdx_model_data_object: model_data = mdx_model_data_object[model_hash]
|
270 |
-
else: raise ValueError(translations["model_not_support"].format(model_hash=model_hash))
|
271 |
-
|
272 |
-
self.logger.debug(translations["uvr_json"].format(model_hash=model_hash, model_data=model_data))
|
273 |
-
return model_data
|
274 |
-
|
275 |
-
def load_model(self, model_filename):
|
276 |
-
self.logger.info(translations["loading_model"].format(model_filename=model_filename))
|
277 |
-
load_model_start_time = time.perf_counter()
|
278 |
-
model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
|
279 |
-
self.logger.debug(translations["download_model_friendly_2"].format(model_friendly_name=model_friendly_name, model_path=model_path))
|
280 |
-
|
281 |
-
if model_path.lower().endswith(".yaml"): yaml_config_filename = model_path
|
282 |
-
|
283 |
-
common_params = {"logger": self.logger, "log_level": self.log_level, "torch_device": self.torch_device, "torch_device_cpu": self.torch_device_cpu, "torch_device_mps": self.torch_device_mps, "onnx_execution_provider": self.onnx_execution_provider, "model_name": model_filename.split(".")[0], "model_path": model_path, "model_data": self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename is not None else self.load_model_data_using_hash(model_path), "output_format": self.output_format, "output_bitrate": self.output_bitrate, "output_dir": self.output_dir, "normalization_threshold": self.normalization_threshold, "output_single_stem": self.output_single_stem, "invert_using_spec": self.invert_using_spec, "sample_rate": self.sample_rate}
|
284 |
-
separator_classes = {"MDX": "mdx_separator.MDXSeparator", "Demucs": "demucs_separator.DemucsSeparator"}
|
285 |
-
|
286 |
-
if model_type not in self.arch_specific_params or model_type not in separator_classes: raise ValueError(translations["model_type_not_support"].format(model_type=model_type))
|
287 |
-
if model_type == "Demucs" and sys.version_info < (3, 10): raise Exception(translations["demucs_not_support_python<3.10"])
|
288 |
-
|
289 |
-
self.logger.debug(f"{translations['import_module']} {model_type}: {separator_classes[model_type]}")
|
290 |
-
module_name, class_name = separator_classes[model_type].split(".")
|
291 |
-
separator_class = getattr(import_module(f"main.library.architectures.{module_name}"), class_name)
|
292 |
-
|
293 |
-
self.logger.debug(f"{translations['initialization']} {model_type}: {separator_class}")
|
294 |
-
self.model_instance = separator_class(common_config=common_params, arch_config=self.arch_specific_params[model_type])
|
295 |
-
|
296 |
-
self.logger.debug(translations["loading_model_success"])
|
297 |
-
self.logger.info(f"{translations['loading_model_duration']}: {time.strftime('%H:%M:%S', time.gmtime(int(time.perf_counter() - load_model_start_time)))}")
|
298 |
-
|
299 |
-
def separate(self, audio_file_path):
|
300 |
-
self.logger.info(f"{translations['starting_separator']}: {audio_file_path}")
|
301 |
-
separate_start_time = time.perf_counter()
|
302 |
-
|
303 |
-
self.logger.debug(translations["normalization"].format(normalization_threshold=self.normalization_threshold))
|
304 |
-
output_files = self.model_instance.separate(audio_file_path)
|
305 |
-
|
306 |
-
self.model_instance.clear_gpu_cache()
|
307 |
-
self.model_instance.clear_file_specific_paths()
|
308 |
-
|
309 |
-
self.print_uvr_vip_message()
|
310 |
-
|
311 |
-
self.logger.debug(translations["separator_success_3"])
|
312 |
-
self.logger.info(f"{translations['separator_duration']}: {time.strftime('%H:%M:%S', time.gmtime(int(time.perf_counter() - separate_start_time)))}")
|
313 |
-
return output_files
|
314 |
-
|
315 |
-
def download_model_and_data(self, model_filename):
|
316 |
-
self.logger.info(translations["loading_separator_model"].format(model_filename=model_filename))
|
317 |
-
model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
|
318 |
-
|
319 |
-
if model_path.lower().endswith(".yaml"): yaml_config_filename = model_path
|
320 |
-
self.logger.info(translations["downloading_model"].format(model_type=model_type, model_friendly_name=model_friendly_name, model_path=model_path, model_data_dict_size=len(self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename is not None else self.load_model_data_using_hash(model_path))))
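Typical use of the Separator wrapper removed here: construct it, load one of the supported UVR model files (downloaded on demand into the model_file_dir, by default assets/models/uvr5), then separate an input track into stems. The model and audio file names below are placeholders:

from main.library.algorithm.separator import Separator

separator = Separator(output_dir="audios", output_format="wav")
separator.load_model("Kim_Vocal_2.onnx")       # fetched into assets/models/uvr5 if missing
output_files = separator.separate("song.wav")  # returns the paths of the written stem files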
main/library/algorithm/stftpitchshift.py
DELETED
@@ -1,250 +0,0 @@
|
|
1 |
-
import numpy as np
|
2 |
-
|
3 |
-
from numpy.lib.stride_tricks import sliding_window_view
|
4 |
-
|
5 |
-
def istft(frames, framesize, hopsize):
|
6 |
-
frames = np.atleast_2d(frames)
|
7 |
-
assert frames.ndim == 2
|
8 |
-
|
9 |
-
analysis_window_size = np.ravel(framesize)[0]
|
10 |
-
synthesis_window_size = np.ravel(framesize)[-1]
|
11 |
-
|
12 |
-
assert analysis_window_size >= synthesis_window_size
|
13 |
-
|
14 |
-
A = asymmetric_analysis_window(analysis_window_size, synthesis_window_size) if analysis_window_size != synthesis_window_size else symmetric_window(analysis_window_size)
|
15 |
-
S = asymmetric_synthesis_window(analysis_window_size, synthesis_window_size) if analysis_window_size != synthesis_window_size else symmetric_window(synthesis_window_size)
|
16 |
-
|
17 |
-
W = S * hopsize / np.sum(A * S)
|
18 |
-
N = frames.shape[0] * hopsize + analysis_window_size
|
19 |
-
|
20 |
-
y = np.zeros((N), float)
|
21 |
-
|
22 |
-
frames[:, 0] = 0
|
23 |
-
frames[:, -1] = 0
|
24 |
-
frames0 = sliding_window_view(y, analysis_window_size, writeable=True)[::hopsize]
|
25 |
-
frames1 = np.fft.irfft(frames, axis=-1, norm='forward') * W
|
26 |
-
|
27 |
-
for i in range(min(len(frames0), len(frames1))):
|
28 |
-
frames0[i] += frames1[i]
|
29 |
-
|
30 |
-
return y
|
31 |
-
|
32 |
-
def asymmetric_synthesis_window(analysis_window_size, synthesis_window_size):
|
33 |
-
n = analysis_window_size
|
34 |
-
m = synthesis_window_size // 2
|
35 |
-
|
36 |
-
right = symmetric_window(2 * m)
|
37 |
-
window = np.zeros(n)
|
38 |
-
|
39 |
-
window[n-m-m:n-m] = np.square(right[:m]) / symmetric_window(2 * n - 2 * m)[n-m-m:n-m]
|
40 |
-
window[-m:] = right[-m:]
|
41 |
-
|
42 |
-
return window
|
43 |
-
|
44 |
-
def asymmetric_analysis_window(analysis_window_size, synthesis_window_size):
|
45 |
-
n = analysis_window_size
|
46 |
-
m = synthesis_window_size // 2
|
47 |
-
|
48 |
-
window = np.zeros(n)
|
49 |
-
window[:n-m] = symmetric_window(2 * n - 2 * m)[:n-m]
|
50 |
-
window[-m:] = symmetric_window(2 * m)[-m:]
|
51 |
-
|
52 |
-
return window
|
53 |
-
|
54 |
-
def symmetric_window(symmetric_window_size):
|
55 |
-
n = symmetric_window_size
|
56 |
-
window = 0.5 - 0.5 * np.cos(2 * np.pi * np.arange(n) / n)
|
57 |
-
|
58 |
-
return window
|
59 |
-
|
60 |
-
def stft(x, framesize, hopsize):
|
61 |
-
x = np.atleast_1d(x)
|
62 |
-
assert x.ndim == 1
|
63 |
-
|
64 |
-
analysis_window_size = np.ravel(framesize)[0]
|
65 |
-
synthesis_window_size = np.ravel(framesize)[-1]
|
66 |
-
|
67 |
-
assert analysis_window_size >= synthesis_window_size
|
68 |
-
|
69 |
-
W = asymmetric_analysis_window(analysis_window_size, synthesis_window_size) if analysis_window_size != synthesis_window_size else symmetric_window(analysis_window_size)
|
70 |
-
|
71 |
-
frames0 = sliding_window_view(x, analysis_window_size, writeable=False)[::hopsize]
|
72 |
-
frames1 = np.fft.rfft(frames0 * W, axis=-1, norm='forward')
|
73 |
-
|
74 |
-
return frames1
|
75 |
-
|
76 |
-
def normalize(frames, frames0):
|
77 |
-
for i in range(len(frames)):
|
78 |
-
a = np.real(frames0[i])
|
79 |
-
b = np.real(frames[i])
|
80 |
-
a = np.dot(a, a)
|
81 |
-
b = np.dot(b, b)
|
82 |
-
|
83 |
-
if b == 0: continue
|
84 |
-
frames[i] = np.real(frames[i]) * np.sqrt(a / b) + 1j * np.imag(frames[i])
|
85 |
-
|
86 |
-
return frames
|
87 |
-
|
88 |
-
def lowpass(cepstrum, quefrency):
|
89 |
-
cepstrum[1:quefrency] *= 2
|
90 |
-
cepstrum[quefrency+1:] = 0
|
91 |
-
|
92 |
-
return cepstrum
|
93 |
-
|
94 |
-
def lifter(frames, quefrency):
|
95 |
-
envelopes = np.zeros(frames.shape)
|
96 |
-
|
97 |
-
for i, frame in enumerate(frames):
|
98 |
-
with np.errstate(divide='ignore', invalid='ignore'):
|
99 |
-
spectrum = np.log10(np.real(frame))
|
100 |
-
|
101 |
-
envelopes[i] = np.power(10, np.real(np.fft.rfft(lowpass(np.fft.irfft(spectrum, norm='forward'), quefrency), norm='forward')))
|
102 |
-
|
103 |
-
return envelopes
|
104 |
-
|
105 |
-
def resample(x, factor):
|
106 |
-
if factor == 1: return x.copy()
|
107 |
-
y = np.zeros(x.shape, dtype=x.dtype)
|
108 |
-
|
109 |
-
n = len(x)
|
110 |
-
m = int(n * factor)
|
111 |
-
|
112 |
-
i = np.arange(min(n, m))
|
113 |
-
k = i * (n / m)
|
114 |
-
|
115 |
-
j = np.trunc(k).astype(int)
|
116 |
-
k = k - j
|
117 |
-
|
118 |
-
ok = (0 <= j) & (j < n - 1)
|
119 |
-
y[i[ok]] = k[ok] * x[j[ok] + 1] + (1 - k[ok]) * x[j[ok]]
|
120 |
-
|
121 |
-
return y
|
122 |
-
|
123 |
-
def shiftpitch(frames, factors, samplerate):
|
124 |
-
for i in range(len(frames)):
|
125 |
-
magnitudes = np.vstack([resample(np.real(frames[i]), factor) for factor in factors])
|
126 |
-
frequencies = np.vstack([resample(np.imag(frames[i]), factor) * factor for factor in factors])
|
127 |
-
|
128 |
-
magnitudes[(frequencies <= 0) | (frequencies >= samplerate / 2)] = 0
|
129 |
-
mask = np.argmax(magnitudes, axis=0)
|
130 |
-
|
131 |
-
magnitudes = np.take_along_axis(magnitudes, mask[None,:], axis=0)
|
132 |
-
frequencies = np.take_along_axis(frequencies, mask[None,:], axis=0)
|
133 |
-
|
134 |
-
frames[i] = magnitudes + 1j * frequencies
|
135 |
-
|
136 |
-
return frames
|
137 |
-
|
138 |
-
def wrap(x):
|
139 |
-
return (x + np.pi) % (2 * np.pi) - np.pi
|
140 |
-
|
141 |
-
def encode(frames, framesize, hopsize, samplerate):
|
142 |
-
M, N = frames.shape
|
143 |
-
analysis_framesize = np.ravel(framesize)[0]
|
144 |
-
|
145 |
-
freqinc = samplerate / analysis_framesize
|
146 |
-
phaseinc = 2 * np.pi * hopsize / analysis_framesize
|
147 |
-
|
148 |
-
buffer = np.zeros(N)
|
149 |
-
data = np.zeros((M, N), complex)
|
150 |
-
|
151 |
-
for m, frame in enumerate(frames):
|
152 |
-
arg = np.angle(frame)
|
153 |
-
delta = arg - buffer
|
154 |
-
|
155 |
-
buffer = arg
|
156 |
-
|
157 |
-
i = np.arange(N)
|
158 |
-
data[m] = np.abs(frame) + 1j * ((i + (wrap(delta - i * phaseinc) / phaseinc)) * freqinc)
|
159 |
-
|
160 |
-
return data
|
161 |
-
|
162 |
-
def decode(frames, framesize, hopsize, samplerate):
|
163 |
-
M, N = frames.shape
|
164 |
-
analysis_framesize = np.ravel(framesize)[0]
|
165 |
-
synthesis_framesize = np.ravel(framesize)[-1]
|
166 |
-
|
167 |
-
freqinc = samplerate / analysis_framesize
|
168 |
-
phaseinc = 2 * np.pi * hopsize / analysis_framesize
|
169 |
-
timeshift = 2 * np.pi * synthesis_framesize * np.arange(N) / N if synthesis_framesize != analysis_framesize else 0
|
170 |
-
|
171 |
-
buffer = np.zeros(N)
|
172 |
-
data = np.zeros((M, N), complex)
|
173 |
-
|
174 |
-
for m, frame in enumerate(frames):
|
175 |
-
i = np.arange(N)
|
176 |
-
delta = (i + ((np.imag(frame) - i * freqinc) / freqinc)) * phaseinc
|
177 |
-
buffer += delta
|
178 |
-
arg = buffer.copy()
|
179 |
-
arg -= timeshift
|
180 |
-
data[m] = np.real(frame) * np.exp(1j * arg)
|
181 |
-
|
182 |
-
return data
|
183 |
-
|
184 |
-
class StftPitchShift:
|
185 |
-
def __init__(self, framesize, hopsize, samplerate):
|
186 |
-
self.framesize = framesize
|
187 |
-
self.hopsize = hopsize
|
188 |
-
self.samplerate = samplerate
|
189 |
-
|
190 |
-
def shiftpitch(self, input, factors = 1, quefrency = 0, distortion = 1, normalization = False):
|
191 |
-
input = np.atleast_1d(input)
|
192 |
-
dtype = input.dtype
|
193 |
-
shape = input.shape
|
194 |
-
|
195 |
-
input = np.squeeze(input)
|
196 |
-
if input.ndim != 1: raise ValueError('input.ndim != 1')
|
197 |
-
|
198 |
-
if np.issubdtype(dtype, np.integer):
|
199 |
-
a, b = np.iinfo(dtype).min, np.iinfo(dtype).max
|
200 |
-
input = ((input.astype(float) - a) / (b - a)) * 2 - 1
|
201 |
-
elif not np.issubdtype(dtype, np.floating): raise TypeError('not np.issubdtype(dtype, np.floating)')
|
202 |
-
|
203 |
-
def isnotnormal(x):
|
204 |
-
return (np.isinf(x)) | (np.isnan(x)) | (abs(x) < np.finfo(x.dtype).tiny)
|
205 |
-
|
206 |
-
framesize = self.framesize
|
207 |
-
hopsize = self.hopsize
|
208 |
-
samplerate = self.samplerate
|
209 |
-
|
210 |
-
factors = np.asarray(factors).flatten()
|
211 |
-
quefrency = int(quefrency * samplerate)
|
212 |
-
|
213 |
-
frames = encode(stft(input, framesize, hopsize), framesize, hopsize, samplerate)
|
214 |
-
|
215 |
-
if normalization: frames0 = frames.copy()
|
216 |
-
|
217 |
-
if quefrency:
|
218 |
-
envelopes = lifter(frames, quefrency)
|
219 |
-
mask = isnotnormal(envelopes)
|
220 |
-
|
221 |
-
frames.real /= envelopes
|
222 |
-
frames.real[mask] = 0
|
223 |
-
|
224 |
-
if distortion != 1:
|
225 |
-
envelopes[mask] = 0
|
226 |
-
|
227 |
-
for i in range(len(envelopes)):
|
228 |
-
envelopes[i] = resample(envelopes[i], distortion)
|
229 |
-
|
230 |
-
mask = isnotnormal(envelopes)
|
231 |
-
|
232 |
-
frames = shiftpitch(frames, factors, samplerate)
|
233 |
-
frames.real *= envelopes
|
234 |
-
frames.real[mask] = 0
|
235 |
-
else: frames = shiftpitch(frames, factors, samplerate)
|
236 |
-
|
237 |
-
if normalization: frames = normalize(frames, frames0)
|
238 |
-
|
239 |
-
output = istft(decode(frames, framesize, hopsize, samplerate), framesize, hopsize)
|
240 |
-
output.resize(shape, refcheck=False)
|
241 |
-
|
242 |
-
if np.issubdtype(dtype, np.integer):
|
243 |
-
a, b = np.iinfo(dtype).min, np.iinfo(dtype).max
|
244 |
-
output = (((output + 1) / 2) * (b - a) + a).clip(a, b).astype(dtype)
|
245 |
-
elif output.dtype != dtype: output = output.astype(dtype)
|
246 |
-
|
247 |
-
assert output.dtype == dtype
|
248 |
-
assert output.shape == shape
|
249 |
-
|
250 |
-
return output
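The removed StftPitchShift class implements a phase-vocoder pitch shifter with optional cepstral envelope (formant) preservation via the quefrency argument. A minimal example on a synthetic tone; the frame size, hop size and sample rate are chosen arbitrarily:

import numpy as np

pitchshifter = StftPitchShift(framesize=1024, hopsize=256, samplerate=16000)

t = np.arange(16000) / 16000
tone = np.sin(2 * np.pi * 220 * t).astype(np.float32)    # 1 second of a 220 Hz sine

shifted = pitchshifter.shiftpitch(tone, factors=2)        # one octave up
corrected = pitchshifter.shiftpitch(tone, factors=2, quefrency=1e-3)  # with envelope correction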
main/library/algorithm/synthesizers.py
DELETED
@@ -1,490 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import sys
|
3 |
-
import math
|
4 |
-
import torch
|
5 |
-
import numpy as np
|
6 |
-
import torch.nn.functional as F
|
7 |
-
|
8 |
-
from torch.nn.utils import remove_weight_norm
|
9 |
-
from torch.utils.checkpoint import checkpoint
|
10 |
-
from torch.nn.utils.parametrizations import weight_norm
|
11 |
-
|
12 |
-
sys.path.append(os.getcwd())
|
13 |
-
|
14 |
-
from .modules import WaveNet
|
15 |
-
from .refinegan import RefineGANGenerator
|
16 |
-
from .mrf_hifigan import HiFiGANMRFGenerator
|
17 |
-
from .residuals import ResidualCouplingBlock, ResBlock, LRELU_SLOPE
|
18 |
-
from .commons import init_weights, slice_segments, rand_slice_segments, sequence_mask, convert_pad_shape
|
19 |
-
|
20 |
-
|
21 |
-
class Generator(torch.nn.Module):
|
22 |
-
def __init__(self, initial_channel, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
|
23 |
-
super(Generator, self).__init__()
|
24 |
-
self.num_kernels = len(resblock_kernel_sizes)
|
25 |
-
self.num_upsamples = len(upsample_rates)
|
26 |
-
self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
|
27 |
-
self.ups_and_resblocks = torch.nn.ModuleList()
|
28 |
-
|
29 |
-
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
30 |
-
self.ups_and_resblocks.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2)))
|
31 |
-
ch = upsample_initial_channel // (2 ** (i + 1))
|
32 |
-
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
33 |
-
self.ups_and_resblocks.append(ResBlock(ch, k, d))
|
34 |
-
|
35 |
-
self.conv_post = torch.nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
36 |
-
self.ups_and_resblocks.apply(init_weights)
|
37 |
-
if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
38 |
-
|
39 |
-
def forward(self, x, g = None):
|
40 |
-
x = self.conv_pre(x)
|
41 |
-
if g is not None: x = x + self.cond(g)
|
42 |
-
|
43 |
-
resblock_idx = 0
|
44 |
-
|
45 |
-
for _ in range(self.num_upsamples):
|
46 |
-
x = self.ups_and_resblocks[resblock_idx](F.leaky_relu(x, LRELU_SLOPE))
|
47 |
-
resblock_idx += 1
|
48 |
-
xs = 0
|
49 |
-
|
50 |
-
for _ in range(self.num_kernels):
|
51 |
-
xs += self.ups_and_resblocks[resblock_idx](x)
|
52 |
-
resblock_idx += 1
|
53 |
-
|
54 |
-
x = xs / self.num_kernels
|
55 |
-
|
56 |
-
return torch.tanh(self.conv_post(F.leaky_relu(x)))
|
57 |
-
|
58 |
-
def __prepare_scriptable__(self):
|
59 |
-
for l in self.ups_and_resblocks:
|
60 |
-
for hook in l._forward_pre_hooks.values():
|
61 |
-
if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(l)
|
62 |
-
|
63 |
-
return self
|
64 |
-
|
65 |
-
def remove_weight_norm(self):
|
66 |
-
for l in self.ups_and_resblocks:
|
67 |
-
remove_weight_norm(l)
|
68 |
-
|
69 |
-
class SineGen(torch.nn.Module):
|
70 |
-
def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0, flag_for_pulse=False):
|
71 |
-
super(SineGen, self).__init__()
|
72 |
-
self.sine_amp = sine_amp
|
73 |
-
self.noise_std = noise_std
|
74 |
-
self.harmonic_num = harmonic_num
|
75 |
-
self.dim = self.harmonic_num + 1
|
76 |
-
self.sampling_rate = samp_rate
|
77 |
-
self.voiced_threshold = voiced_threshold
|
78 |
-
|
79 |
-
def _f02uv(self, f0):
|
80 |
-
return torch.ones_like(f0) * (f0 > self.voiced_threshold)
|
81 |
-
|
82 |
-
def _f02sine(self, f0, upp):
|
83 |
-
rad = f0 / self.sampling_rate * torch.arange(1, upp + 1, dtype=f0.dtype, device=f0.device)
|
84 |
-
rad += F.pad((torch.fmod(rad[:, :-1, -1:].float() + 0.5, 1.0) - 0.5).cumsum(dim=1).fmod(1.0).to(f0), (0, 0, 1, 0), mode='constant')
|
85 |
-
rad = rad.reshape(f0.shape[0], -1, 1)
|
86 |
-
rad *= torch.arange(1, self.dim + 1, dtype=f0.dtype, device=f0.device).reshape(1, 1, -1)
|
87 |
-
rand_ini = torch.rand(1, 1, self.dim, device=f0.device)
|
88 |
-
rand_ini[..., 0] = 0
|
89 |
-
rad += rand_ini
|
90 |
-
|
91 |
-
return torch.sin(2 * np.pi * rad)
|
92 |
-
|
93 |
-
def forward(self, f0, upp):
|
94 |
-
with torch.no_grad():
|
95 |
-
f0 = f0.unsqueeze(-1)
|
96 |
-
sine_waves = self._f02sine(f0, upp) * self.sine_amp
|
97 |
-
uv = F.interpolate(self._f02uv(f0).transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1)
|
98 |
-
sine_waves = sine_waves * uv + ((uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves))
|
99 |
-
|
100 |
-
return sine_waves
|
101 |
-
|
102 |
-
class SourceModuleHnNSF(torch.nn.Module):
|
103 |
-
def __init__(self, sample_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0):
|
104 |
-
super(SourceModuleHnNSF, self).__init__()
|
105 |
-
self.sine_amp = sine_amp
|
106 |
-
        self.noise_std = add_noise_std
        self.l_sin_gen = SineGen(sample_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x, upsample_factor = 1):
        return self.l_tanh(self.l_linear(self.l_sin_gen(x, upsample_factor).to(dtype=self.l_linear.weight.dtype)))

class GeneratorNSF(torch.nn.Module):
    def __init__(self, initial_channel, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels, sr, checkpointing = False):
        super(GeneratorNSF, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.upp = math.prod(upsample_rates)
        self.f0_upsamp = torch.nn.Upsample(scale_factor=self.upp)
        self.m_source = SourceModuleHnNSF(sample_rate=sr, harmonic_num=0)

        self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
        self.checkpointing = checkpointing

        self.ups = torch.nn.ModuleList()
        self.noise_convs = torch.nn.ModuleList()

        channels = [upsample_initial_channel // (2 ** (i + 1)) for i in range(self.num_upsamples)]
        stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < self.num_upsamples else 1 for i in range(self.num_upsamples)]

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), channels[i], k, u, padding=((k - u) // 2) if u % 2 == 0 else (u // 2 + u % 2), output_padding=u % 2)))
            stride = stride_f0s[i]
            kernel = 1 if stride == 1 else stride * 2 - stride % 2
            self.noise_convs.append(torch.nn.Conv1d(1, channels[i], kernel_size=kernel, stride=stride, padding=0 if stride == 1 else (kernel - stride) // 2))

        self.resblocks = torch.nn.ModuleList([ResBlock(channels[i], k, d) for i in range(len(self.ups)) for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes)])
        self.conv_post = torch.nn.Conv1d(channels[-1], 1, 7, 1, padding=3, bias=False)

        self.ups.apply(init_weights)
        if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, f0, g = None):
        har_source = self.m_source(f0, self.upp).transpose(1, 2)
        x = self.conv_pre(x)
        if g is not None: x += self.cond(g)

        for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
            x = F.leaky_relu(x, LRELU_SLOPE)

            if self.training and self.checkpointing:
                x = checkpoint(ups, x, use_reentrant=False) + noise_convs(har_source)
                xs = sum([checkpoint(resblock, x, use_reentrant=False) for j, resblock in enumerate(self.resblocks) if j in range(i * self.num_kernels, (i + 1) * self.num_kernels)])
            else:
                x = ups(x) + noise_convs(har_source)
                xs = sum([resblock(x) for j, resblock in enumerate(self.resblocks) if j in range(i * self.num_kernels, (i + 1) * self.num_kernels)])

            x = xs / self.num_kernels

        return torch.tanh(self.conv_post(F.leaky_relu(x)))

    def remove_weight_norm(self):
        for l in self.ups:
            remove_weight_norm(l)

        for l in self.resblocks:
            l.remove_weight_norm()

class LayerNorm(torch.nn.Module):
    def __init__(self, channels, eps=1e-5, onnx=False):
        super().__init__()
        self.channels = channels
        self.eps = eps
        self.onnx = onnx
        self.gamma = torch.nn.Parameter(torch.ones(channels))
        self.beta = torch.nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        return (F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) if self.onnx else F.layer_norm(x, (x.size(-1),), self.gamma, self.beta, self.eps)).transpose(1, -1)

class MultiHeadAttention(torch.nn.Module):
    def __init__(self, channels, out_channels, n_heads, p_dropout=0.0, window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False, onnx=False):
        super().__init__()
        assert channels % n_heads == 0
        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.onnx = onnx
        self.attn = None
        self.k_channels = channels // n_heads
        self.conv_q = torch.nn.Conv1d(channels, channels, 1)
        self.conv_k = torch.nn.Conv1d(channels, channels, 1)
        self.conv_v = torch.nn.Conv1d(channels, channels, 1)
        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
        self.drop = torch.nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5

            self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
            self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)

        torch.nn.init.xavier_uniform_(self.conv_q.weight)
        torch.nn.init.xavier_uniform_(self.conv_k.weight)
        torch.nn.init.xavier_uniform_(self.conv_v.weight)

        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        q, k, v = self.conv_q(x), self.conv_k(c), self.conv_v(c)
        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        return self.conv_o(x)

    def attention(self, query, key, value, mask=None):
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))

        if self.window_size is not None:
            assert (t_s == t_t), "(t_s == t_t)"
            scores = scores + self._relative_position_to_absolute_position(self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), self._get_relative_embeddings(self.emb_rel_k, t_s, onnx=self.onnx)), onnx=self.onnx)

        if self.proximal_bias:
            assert t_s == t_t, "t_s == t_t"
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (t_s == t_t), "(t_s == t_t)"
                scores = scores.masked_fill((torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)) == 0, -1e4)

        p_attn = self.drop(F.softmax(scores, dim=-1))
        output = torch.matmul(p_attn, value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3))

        if self.window_size is not None: output = output + self._matmul_with_relative_values(self._absolute_position_to_relative_position(p_attn, onnx=self.onnx), self._get_relative_embeddings(self.emb_rel_v, t_s, onnx=self.onnx))
        return (output.transpose(2, 3).contiguous().view(b, d, t_t)), p_attn

    def _matmul_with_relative_values(self, x, y):
        return torch.matmul(x, y.unsqueeze(0))

    def _matmul_with_relative_keys(self, x, y):
        return torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))

    def _get_relative_embeddings(self, relative_embeddings, length, onnx=False):
        if onnx:
            pad_length = torch.clamp(length - (self.window_size + 1), min=0)
            slice_start_position = torch.clamp((self.window_size + 1) - length, min=0)

            return (F.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0]) if pad_length > 0 else relative_embeddings)[:, slice_start_position:(slice_start_position + 2 * length - 1)]
        else:
            pad_length = max(length - (self.window_size + 1), 0)
            slice_start_position = max((self.window_size + 1) - length, 0)

            return (F.pad(relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) if pad_length > 0 else relative_embeddings)[:, slice_start_position:(slice_start_position + 2 * length - 1)]

    def _relative_position_to_absolute_position(self, x, onnx=False):
        batch, heads, length, _ = x.size()

        return (F.pad(F.pad(x, [0, 1, 0, 0, 0, 0, 0, 0]).view([batch, heads, length * 2 * length]), [0, length - 1, 0, 0, 0, 0]).view([batch, heads, length + 1, 2 * length - 1]) if onnx else F.pad(F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])).view([batch, heads, length * 2 * length]), convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])).view([batch, heads, length + 1, 2 * length - 1]))[:, :, :length, length - 1 :]

    def _absolute_position_to_relative_position(self, x, onnx=False):
        batch, heads, length, _ = x.size()

        return (F.pad(F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0]).view([batch, heads, length*length + length * (length - 1)]), [length, 0, 0, 0, 0, 0]).view([batch, heads, length, 2 * length]) if onnx else F.pad(F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])).view([batch, heads, length**2 + length * (length - 1)]), convert_pad_shape([[0, 0], [0, 0], [length, 0]])).view([batch, heads, length, 2 * length]))[:, :, :, 1:]

    def _attention_bias_proximal(self, length):
        r = torch.arange(length, dtype=torch.float32)

        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs((torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)))), 0), 0)

class FFN(torch.nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0, activation=None, causal=False, onnx=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal
        self.onnx = onnx
        self.padding = self._causal_padding if causal else self._same_padding
        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = torch.nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(self.padding(x * x_mask))

        return self.conv_2(self.padding(self.drop(((x * torch.sigmoid(1.702 * x)) if self.activation == "gelu" else torch.relu(x))) * x_mask)) * x_mask

    def _causal_padding(self, x):
        if self.kernel_size == 1: return x

        return F.pad(x, [self.kernel_size - 1, 0, 0, 0, 0, 0]) if self.onnx else F.pad(x, convert_pad_shape([[0, 0], [0, 0], [(self.kernel_size - 1), 0]]))

    def _same_padding(self, x):
        if self.kernel_size == 1: return x

        return F.pad(x, [(self.kernel_size - 1) // 2, self.kernel_size // 2, 0, 0, 0, 0]) if self.onnx else F.pad(x, convert_pad_shape([[0, 0], [0, 0], [((self.kernel_size - 1) // 2), (self.kernel_size // 2)]]))

class Encoder(torch.nn.Module):
    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.0, window_size=10, onnx=False, **kwargs):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.drop = torch.nn.Dropout(p_dropout)
        self.attn_layers = torch.nn.ModuleList()
        self.norm_layers_1 = torch.nn.ModuleList()
        self.ffn_layers = torch.nn.ModuleList()
        self.norm_layers_2 = torch.nn.ModuleList()

        for _ in range(self.n_layers):
            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size, onnx=onnx))
            self.norm_layers_1.append(LayerNorm(hidden_channels, onnx=onnx))

            self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, onnx=onnx))
            self.norm_layers_2.append(LayerNorm(hidden_channels, onnx=onnx))

    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask

        for i in range(self.n_layers):
            x = self.norm_layers_1[i](x + self.drop(self.attn_layers[i](x, x, attn_mask)))
            x = self.norm_layers_2[i](x + self.drop(self.ffn_layers[i](x, x_mask)))

        return x * x_mask

class TextEncoder(torch.nn.Module):
    def __init__(self, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, embedding_dim, f0=True, onnx=False):
        super(TextEncoder, self).__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = float(p_dropout)
        self.emb_phone = torch.nn.Linear(embedding_dim, hidden_channels)
        self.lrelu = torch.nn.LeakyReLU(0.1, inplace=True)
        if f0: self.emb_pitch = torch.nn.Embedding(256, hidden_channels)
        self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), onnx=onnx)
        self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, phone, pitch, lengths):
        x = torch.transpose(self.lrelu(((self.emb_phone(phone) if pitch is None else (self.emb_phone(phone) + self.emb_pitch(pitch))) * math.sqrt(self.hidden_channels))), 1, -1)
        x_mask = torch.unsqueeze(sequence_mask(lengths, x.size(2)), 1).to(x.dtype)
        m, logs = torch.split((self.proj(self.encoder(x * x_mask, x_mask)) * x_mask), self.out_channels, dim=1)

        return m, logs, x_mask

class PosteriorEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0):
        super(PosteriorEncoder, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.pre = torch.nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = WaveNet(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
        self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g = None):
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        m, logs = torch.split((self.proj(self.enc((self.pre(x) * x_mask), x_mask, g=g)) * x_mask), self.out_channels, dim=1)

        return ((m + torch.randn_like(m) * torch.exp(logs)) * x_mask), m, logs, x_mask

    def remove_weight_norm(self):
        self.enc.remove_weight_norm()

class Synthesizer(torch.nn.Module):
    def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, use_f0, text_enc_hidden_dim=768, vocoder="Default", checkpointing=False, onnx=False, **kwargs):
        super(Synthesizer, self).__init__()
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = float(p_dropout)
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.spk_embed_dim = spk_embed_dim
        self.use_f0 = use_f0
        self.enc_p = TextEncoder(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), text_enc_hidden_dim, f0=use_f0, onnx=onnx)

        if use_f0:
            if vocoder == "RefineGAN": self.dec = RefineGANGenerator(sample_rate=sr, upsample_rates=upsample_rates, num_mels=inter_channels, checkpointing=checkpointing)
            elif vocoder in ["MRF-HiFi-GAN", "MRF HiFi-GAN"]: self.dec = HiFiGANMRFGenerator(in_channel=inter_channels, upsample_initial_channel=upsample_initial_channel, upsample_rates=upsample_rates, upsample_kernel_sizes=upsample_kernel_sizes, resblock_kernel_sizes=resblock_kernel_sizes, resblock_dilations=resblock_dilation_sizes, gin_channels=gin_channels, sample_rate=sr, harmonic_num=8, checkpointing=checkpointing)
            else: self.dec = GeneratorNSF(inter_channels, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels, sr=sr, checkpointing=checkpointing)
        else: self.dec = Generator(inter_channels, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)

        self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels)
        self.emb_g = torch.nn.Embedding(self.spk_embed_dim, gin_channels)

    def remove_weight_norm(self):
        self.dec.remove_weight_norm()
        self.flow.remove_weight_norm()
        self.enc_q.remove_weight_norm()

    @torch.jit.ignore
    def forward(self, phone, phone_lengths, pitch = None, pitchf = None, y = None, y_lengths = None, ds = None):
        g = self.emb_g(ds).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)

        if y is not None:
            z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
            z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size)

            return (self.dec(z_slice, slice_segments(pitchf, ids_slice, self.segment_size, 2), g=g) if self.use_f0 else self.dec(z_slice, g=g)), ids_slice, x_mask, y_mask, (z, self.flow(z, y_mask, g=g), m_p, logs_p, m_q, logs_q)
        else: return None, None, x_mask, None, (None, None, m_p, logs_p, None, None)

    @torch.jit.export
    def infer(self, phone, phone_lengths, pitch = None, nsff0 = None, sid = None, rate = None):
        g = self.emb_g(sid).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask

        if rate is not None:
            assert isinstance(rate, torch.Tensor)
            head = int(z_p.shape[2] * (1.0 - rate.item()))
            z_p = z_p[:, :, head:]
            x_mask = x_mask[:, :, head:]
            if self.use_f0: nsff0 = nsff0[:, head:]

        if self.use_f0:
            z = self.flow(z_p, x_mask, g=g, reverse=True)
            o = self.dec(z * x_mask, nsff0, g=g)
        else:
            z = self.flow(z_p, x_mask, g=g, reverse=True)
            o = self.dec(z * x_mask, g=g)

        return o, x_mask, (z, z_p, m_p, logs_p)

class SynthesizerONNX(Synthesizer):
    def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, use_f0, text_enc_hidden_dim=768, vocoder="Default", checkpointing=False, **kwargs):
        super().__init__(spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, use_f0, text_enc_hidden_dim, vocoder, checkpointing, True)
        self.speaker_map = None

    def remove_weight_norm(self):
        self.dec.remove_weight_norm()
        self.flow.remove_weight_norm()
        self.enc_q.remove_weight_norm()

    def construct_spkmixmap(self, n_speaker):
        self.speaker_map = torch.zeros((n_speaker, 1, 1, self.gin_channels))

        for i in range(n_speaker):
            self.speaker_map[i] = self.emb_g(torch.LongTensor([[i]]))

        self.speaker_map = self.speaker_map.unsqueeze(0)

    def forward(self, phone, phone_lengths, g=None, rnd=None, pitch=None, nsff0=None, max_len=None):
        g = self.emb_g(g).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        z_p = (m_p + torch.exp(logs_p) * rnd) * x_mask

        return self.dec((self.flow(z_p, x_mask, g=g, reverse=True) * x_mask)[:, :, :max_len], nsff0, g=g) if self.use_f0 else self.dec((self.flow(z_p, x_mask, g=g, reverse=True) * x_mask)[:, :, :max_len], g=g)
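For orientation, a minimal inference sketch for the Synthesizer above; the hyperparameters are illustrative assumptions (loosely mimicking a 40 kHz RVC-style setup), not values taken from this repository's configs, and it presupposes the rest of the deleted module (SineGen, WaveNet, ResidualCouplingBlock, ...) is still importable.

import torch

# All values below are assumptions for illustration only.
net = Synthesizer(
    spec_channels=1025, segment_size=32, inter_channels=192, hidden_channels=192,
    filter_channels=768, n_heads=2, n_layers=6, kernel_size=3, p_dropout=0.0,
    resblock="1", resblock_kernel_sizes=[3, 7, 11], resblock_dilation_sizes=[[1, 3, 5]] * 3,
    upsample_rates=[10, 10, 2, 2], upsample_initial_channel=512, upsample_kernel_sizes=[16, 16, 4, 4],
    spk_embed_dim=109, gin_channels=256, sr=40000, use_f0=True,
).eval()

phone = torch.randn(1, 200, 768)            # content features (text_enc_hidden_dim=768)
phone_lengths = torch.tensor([200])
pitch = torch.randint(1, 255, (1, 200))     # coarse f0 bins for emb_pitch
nsff0 = torch.rand(1, 200) * 200.0 + 50.0   # f0 in Hz for the NSF harmonic source (assumed shape)
sid = torch.tensor([0])                     # speaker id into emb_g

with torch.no_grad():
    audio, _, _ = net.infer(phone, phone_lengths, pitch, nsff0, sid)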
main/library/architectures/demucs_separator.py
DELETED
@@ -1,180 +0,0 @@

import os
import sys
import yaml
import torch

import numpy as np
from hashlib import sha256

sys.path.append(os.getcwd())

from main.configs.config import Config
from main.library.uvr5_separator import spec_utils, common_separator
from main.library.uvr5_separator.demucs import hdemucs, states, apply

translations = Config().translations
sys.path.insert(0, os.path.join(os.getcwd(), "main", "library", "uvr5_separator"))

DEMUCS_4_SOURCE_MAPPER = {common_separator.CommonSeparator.BASS_STEM: 0, common_separator.CommonSeparator.DRUM_STEM: 1, common_separator.CommonSeparator.OTHER_STEM: 2, common_separator.CommonSeparator.VOCAL_STEM: 3}


class DemucsSeparator(common_separator.CommonSeparator):
    def __init__(self, common_config, arch_config):
        super().__init__(config=common_config)
        self.segment_size = arch_config.get("segment_size", "Default")
        self.shifts = arch_config.get("shifts", 2)
        self.overlap = arch_config.get("overlap", 0.25)
        self.segments_enabled = arch_config.get("segments_enabled", True)
        self.logger.debug(translations["demucs_info"].format(segment_size=self.segment_size, segments_enabled=self.segments_enabled))
        self.logger.debug(translations["demucs_info_2"].format(shifts=self.shifts, overlap=self.overlap))
        self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER
        self.audio_file_path = None
        self.audio_file_base = None
        self.demucs_model_instance = None
        self.logger.info(translations["start_demucs"])

    def separate(self, audio_file_path):
        self.logger.debug(translations["start_separator"])
        source = None
        inst_source = {}
        self.audio_file_path = audio_file_path
        self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
        self.logger.debug(translations["prepare_mix"])
        mix = self.prepare_mix(self.audio_file_path)
        self.logger.debug(translations["demix"].format(shape=mix.shape))
        self.logger.debug(translations["cancel_mix"])
        self.demucs_model_instance = hdemucs.HDemucs(sources=["drums", "bass", "other", "vocals"])
        self.demucs_model_instance = get_demucs_model(name=os.path.splitext(os.path.basename(self.model_path))[0], repo=os.path.dirname(self.model_path))
        self.demucs_model_instance = apply.demucs_segments(self.segment_size, self.demucs_model_instance)
        self.demucs_model_instance.to(self.torch_device)
        self.demucs_model_instance.eval()
        self.logger.debug(translations["model_review"])
        source = self.demix_demucs(mix)
        del self.demucs_model_instance
        self.clear_gpu_cache()
        self.logger.debug(translations["del_gpu_cache_after_demix"])
        output_files = []
        self.logger.debug(translations["process_output_file"])

        if isinstance(inst_source, np.ndarray):
            self.logger.debug(translations["process_ver"])
            inst_source[self.demucs_source_map[common_separator.CommonSeparator.VOCAL_STEM]] = spec_utils.reshape_sources(inst_source[self.demucs_source_map[common_separator.CommonSeparator.VOCAL_STEM]], source[self.demucs_source_map[common_separator.CommonSeparator.VOCAL_STEM]])
            source = inst_source

        if isinstance(source, np.ndarray):
            source_length = len(source)
            self.logger.debug(translations["source_length"].format(source_length=source_length))
            self.logger.debug(translations["set_map"].format(part=source_length))

            match source_length:
                case 2: self.demucs_source_map = {common_separator.CommonSeparator.INST_STEM: 0, common_separator.CommonSeparator.VOCAL_STEM: 1}
                case 6: self.demucs_source_map = {common_separator.CommonSeparator.BASS_STEM: 0, common_separator.CommonSeparator.DRUM_STEM: 1, common_separator.CommonSeparator.OTHER_STEM: 2, common_separator.CommonSeparator.VOCAL_STEM: 3, common_separator.CommonSeparator.GUITAR_STEM: 4, common_separator.CommonSeparator.PIANO_STEM: 5}
                case _: self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER

        self.logger.debug(translations["process_all_part"])

        for stem_name, stem_value in self.demucs_source_map.items():
            if self.output_single_stem is not None:
                if stem_name.lower() != self.output_single_stem.lower():
                    self.logger.debug(translations["skip_part"].format(stem_name=stem_name, output_single_stem=self.output_single_stem))
                    continue

            stem_path = os.path.join(f"{self.audio_file_base}_({stem_name})_{self.model_name}.{self.output_format.lower()}")
            self.final_process(stem_path, source[stem_value].T, stem_name)
            output_files.append(stem_path)

        return output_files

    def demix_demucs(self, mix):
        self.logger.debug(translations["starting_demix_demucs"])
        processed = {}
        mix = torch.tensor(mix, dtype=torch.float32)
        ref = mix.mean(0)
        mix = (mix - ref.mean()) / ref.std()
        mix_infer = mix

        with torch.no_grad():
            self.logger.debug(translations["model_infer"])
            sources = apply.apply_model(model=self.demucs_model_instance, mix=mix_infer[None], shifts=self.shifts, split=self.segments_enabled, overlap=self.overlap, static_shifts=1 if self.shifts == 0 else self.shifts, set_progress_bar=None, device=self.torch_device, progress=True)[0]

        sources = (sources * ref.std() + ref.mean()).cpu().numpy()
        sources[[0, 1]] = sources[[1, 0]]

        processed[mix] = sources[:, :, 0:None].copy()
        return np.concatenate([s[:, :, 0:None] for s in list(processed.values())], axis=-1)

class LocalRepo:
    def __init__(self, root):
        self.root = root
        self.scan()

    def scan(self):
        self._models, self._checksums = {}, {}
        for filename in os.listdir(self.root):
            filepath = os.path.join(self.root, filename)
            if not os.path.isfile(filepath): continue

            if os.path.splitext(filename)[1] == ".th":
                stem = os.path.splitext(filename)[0]

                if "-" in stem:
                    xp_sig, checksum = stem.split("-", 1)
                    self._checksums[xp_sig] = checksum
                else: xp_sig = stem

                if xp_sig in self._models: raise RuntimeError(translations["del_all_but_one"].format(xp_sig=xp_sig))
                self._models[xp_sig] = filepath

    def has_model(self, sig):
        return sig in self._models

    def get_model(self, sig):
        try:
            file = self._models[sig]
        except KeyError:
            raise RuntimeError(translations["not_found_model_signature"].format(sig=sig))

        if sig in self._checksums: check_checksum(file, self._checksums[sig])
        return states.load_model(file)

class BagOnlyRepo:
    def __init__(self, root, model_repo):
        self.root = root
        self.model_repo = model_repo
        self.scan()

    def scan(self):
        self._bags = {}
        for filename in os.listdir(self.root):
            filepath = os.path.join(self.root, filename)

            if os.path.isfile(filepath) and os.path.splitext(filename)[1] == ".yaml":
                stem = os.path.splitext(filename)[0]
                self._bags[stem] = filepath

    def get_model(self, name):
        try:
            yaml_file = self._bags[name]
        except KeyError:
            raise RuntimeError(translations["name_not_pretrained"].format(name=name))

        with open(yaml_file, 'r') as f:
            bag = yaml.safe_load(f)

        return apply.BagOfModels([self.model_repo.get_model(sig) for sig in bag["models"]], bag.get("weights"), bag.get("segment"))

def check_checksum(path, checksum):
    sha = sha256()

    with open(path, "rb") as file:
        while 1:
            buf = file.read(2**20)
            if not buf: break
            sha.update(buf)

    actual_checksum = sha.hexdigest()[:len(checksum)]
    if actual_checksum != checksum: raise RuntimeError(translations["invalid_checksum"].format(path=path, checksum=checksum, actual_checksum=actual_checksum))

def get_demucs_model(name, repo = None):
    model_repo = LocalRepo(repo)
    return (model_repo.get_model(name) if model_repo.has_model(name) else BagOnlyRepo(repo, model_repo).get_model(name)).eval()
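The two repo helpers above resolve a Demucs model either from a signature-named "*.th" checkpoint (LocalRepo) or from a "*.yaml" bag (BagOnlyRepo). A minimal sketch, assuming a hypothetical local folder "assets/demucs" holding such files; both the folder and the model name are illustrative assumptions, not paths used by this repository:

# Returns a single loaded model, or apply.BagOfModels for a .yaml bag, already switched to eval().
demucs = get_demucs_model(name="htdemucs", repo="assets/demucs")
print(type(demucs))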
main/library/architectures/fairseq.py
DELETED
@@ -1,1480 +0,0 @@

import re
import sys
import math
import uuid
import torch
import types
import contextlib

import numpy as np
import torch.nn.functional as F

from torch import nn
from omegaconf import DictConfig, open_dict

class Dictionary:
    def __init__(self, *args, **kwargs):
        pass

fairseq = types.ModuleType("fairseq")
fairseq_data = types.ModuleType("fairseq.data")
fairseq_data_dictionary = types.ModuleType("fairseq.data.dictionary")
fairseq_data_dictionary.Dictionary = Dictionary
fairseq.data = fairseq_data
fairseq_data.dictionary = fairseq_data_dictionary

sys.modules["fairseq"] = fairseq
sys.modules["fairseq.data"] = fairseq_data
sys.modules["fairseq.data.dictionary"] = fairseq_data_dictionary

def load_model(filename):
    state = torch.load(filename, map_location="cpu")

    model = HubertModel(HubertConfig(**state['cfg']['model']))
    model.load_state_dict(state['model'], strict=False)

    return [model], Model_Config(state["cfg"]), Model_Config(state["cfg"]["task"])

def softmax(x, dim, onnx_trace = False):
    return F.softmax(x.float(), dim=dim) if onnx_trace else F.softmax(x, dim=dim, dtype=torch.float32)

def log_softmax(x, dim, onnx_trace = False):
    return F.log_softmax(x.float(), dim=dim) if onnx_trace else F.log_softmax(x, dim=dim, dtype=torch.float32)

def eval_str_dict(x, type=dict):
    if x is None: return None
    if isinstance(x, str): x = eval(x)
    return x

def with_incremental_state(cls):
    cls.__bases__ = (FairseqIncrementalState,) + tuple(b for b in cls.__bases__ if b != FairseqIncrementalState)
    return cls

def quant_noise(module, p, block_size):
    if p <= 0: return module
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    is_conv = module.weight.ndim == 4
    if not is_conv: assert (module.weight.size(1) % block_size == 0)
    else:
        if module.kernel_size == (1, 1): assert (module.in_channels % block_size == 0)
        else:
            k = module.kernel_size[0] * module.kernel_size[1]
            assert k % block_size == 0

    def _forward_pre_hook(mod, input):
        if mod.training:
            if not is_conv:
                weight = mod.weight
                in_features = weight.size(1)
                out_features = weight.size(0)

                mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
                mask.bernoulli_(p)
                mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
            else:
                weight = mod.weight
                in_channels = mod.in_channels
                out_channels = mod.out_channels

                if mod.kernel_size == (1, 1):
                    mask = torch.zeros(int(in_channels // block_size * out_channels), device=weight.device)
                    mask.bernoulli_(p)
                    mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
                else:
                    mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
                    mask.bernoulli_(p)
                    mask = (mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]))

            mask = mask.to(torch.bool)
            s = 1 / (1 - p)
            mod.weight.data = s * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module

class FairseqDropout(nn.Module):
    def __init__(self, p, module_name=None):
        super().__init__()
        self.p = p
        self.module_name = module_name
        self.apply_during_inference = False

    def forward(self, x, inplace = False):
        return F.dropout(x, p=self.p, training=True, inplace=inplace) if self.p > 0 and (self.training or self.apply_during_inference) else x

    def make_generation_fast_(self, name, retain_dropout = False, retain_dropout_modules = None, **kwargs):
        if retain_dropout:
            if (retain_dropout_modules is None or self.module_name in retain_dropout_modules): self.apply_during_inference = True

class FairseqIncrementalState(object):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key):
        return "{}.{}".format(self._incremental_state_id, key)

    def get_incremental_state(self, incremental_state, key):
        full_key = self._get_full_incremental_state_key(key)
        if incremental_state is None or full_key not in incremental_state: return None
        return incremental_state[full_key]

    def set_incremental_state(self, incremental_state, key, value):
        if incremental_state is not None: incremental_state[self._get_full_incremental_state_key(key)] = value
        return incremental_state

class FairseqDecoder(nn.Module):
    def __init__(self, dictionary):
        super().__init__()
        self.dictionary = dictionary
        self.onnx_trace = False
        self.adaptive_softmax = None

    def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
        x, extra = self.extract_features(prev_output_tokens, encoder_out=encoder_out, **kwargs)
        return self.output_layer(x), extra

    def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs):
        pass

    def output_layer(self, features, **kwargs):
        pass

    def get_normalized_probs(self, net_output, log_probs, sample = None):
        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)

    def get_normalized_probs_scriptable(self, net_output, log_probs, sample = None):
        if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
            if sample is not None:
                assert "target" in sample
                target = sample["target"]
            else: target = None

            out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
            return out.exp_() if not log_probs else out

        logits = net_output[0]
        return log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace) if log_probs else softmax(logits, dim=-1, onnx_trace=self.onnx_trace)

    def max_positions(self):
        return 1e6

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

@with_incremental_state
class FairseqIncrementalDecoder(FairseqDecoder):
    def __init__(self, dictionary):
        super().__init__(dictionary)

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
        pass

    def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
        pass

    def reorder_incremental_state(self, incremental_state, new_order):
        pass

    def reorder_incremental_state_scripting(self, incremental_state, new_order):
        for module in self.modules():
            if hasattr(module, "reorder_incremental_state"):
                result = module.reorder_incremental_state(incremental_state, new_order)
                if result is not None: incremental_state = result

    def set_beam_size(self, beam_size):
        if getattr(self, "_beam_size", -1) != beam_size:
            seen = set()

            def apply_set_beam_size(module):
                if (module != self and hasattr(module, "set_beam_size") and module not in seen):
                    seen.add(module)
                    module.set_beam_size(beam_size)

            self.apply(apply_set_beam_size)
            self._beam_size = beam_size

class MultiheadAttention(FairseqIncrementalDecoder):
    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, self_attention=False, encoder_decoder_attention=False, dictionary=None, q_noise=0.0, qn_block_size=8, xformers_att_config=None, xformers_blocksparse_layout=None, xformers_blocksparse_blocksize=16):
        super().__init__(dictionary)
        xformers_att_config = eval_str_dict(xformers_att_config)
        self.use_xformers = xformers_att_config is not None
        if self.use_xformers: raise ImportError
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
        self.num_heads = num_heads
        self.dropout_module = FairseqDropout(dropout, module_name=self.__class__.__name__)
        self.head_dim = embed_dim // num_heads
        assert (self.head_dim * num_heads == self.embed_dim)
        self.scaling = self.head_dim**-0.5
        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention
        assert not self.self_attention or self.qkv_same_dim
        self.k_proj = quant_noise(nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.q_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
        if add_bias_kv:
            self.bias_k = nn.Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = nn.Parameter(torch.Tensor(1, 1, embed_dim))
        else: self.bias_k = self.bias_v = None
        self.add_zero_attn = add_zero_attn
        self.beam_size = 1
        self.reset_parameters()
        self.onnx_trace = False
        self.skip_embed_dim_check = False
        self.init_incremental_state()

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def reset_parameters(self):
        if self.qkv_same_dim:
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)

        if self.out_proj.bias is not None: nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None: nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None: nn.init.xavier_normal_(self.bias_v)

    def _get_reserve_head_index(self, num_heads_to_keep: int):
        k_proj_heads_norm, q_proj_heads_norm, v_proj_heads_norm = [], [], []
        for i in range(self.num_heads):
            start_idx = i * self.head_dim
            end_idx = (i + 1) * self.head_dim
            k_proj_heads_norm.append(torch.sum(torch.abs(self.k_proj.weight[start_idx:end_idx])).tolist() + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist())
            q_proj_heads_norm.append(torch.sum(torch.abs(self.q_proj.weight[start_idx:end_idx])).tolist() + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist())
            v_proj_heads_norm.append(torch.sum(torch.abs(self.v_proj.weight[start_idx:end_idx])).tolist() + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist())

        heads_norm = []
        for i in range(self.num_heads):
            heads_norm.append(k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i])

        sorted_head_index = sorted(range(self.num_heads), key=lambda k: heads_norm[k], reverse=True)
        reserve_head_index = []
        for i in range(num_heads_to_keep):
            reserve_head_index.append((sorted_head_index[i] * self.head_dim, (sorted_head_index[i] + 1) * self.head_dim))
        return reserve_head_index

    def _adaptive_prune_heads(self, reserve_head_index):
        new_q_weight, new_q_bias, new_k_weight, new_k_bias, new_v_weight, new_v_bias, new_out_proj_weight = [], [], [], [], [], [], []

        for ele in reserve_head_index:
            start_idx, end_idx = ele
            new_q_weight.append(self.q_proj.weight[start_idx:end_idx])
            new_q_bias.append(self.q_proj.bias[start_idx:end_idx])
            new_k_weight.append(self.k_proj.weight[start_idx:end_idx])
            new_k_bias.append(self.k_proj.bias[start_idx:end_idx])
            new_v_weight.append(self.v_proj.weight[start_idx:end_idx])
            new_v_bias.append(self.v_proj.bias[start_idx:end_idx])
            new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx])

        new_q_weight = torch.cat(new_q_weight).detach()
        new_k_weight = torch.cat(new_k_weight).detach()
        new_v_weight = torch.cat(new_v_weight).detach()
        new_out_proj_weight = torch.cat(new_out_proj_weight, dim=-1).detach()
        new_q_weight.requires_grad = True
        new_k_weight.requires_grad = True
        new_v_weight.requires_grad = True
        new_out_proj_weight.requires_grad = True
        new_q_bias = torch.cat(new_q_bias).detach()
        new_q_bias.requires_grad = True
        new_k_bias = torch.cat(new_k_bias).detach()
        new_k_bias.requires_grad = True
        new_v_bias = torch.cat(new_v_bias).detach()
        new_v_bias.requires_grad = True

        self.q_proj.weight = nn.Parameter(new_q_weight)
        self.q_proj.bias = nn.Parameter(new_q_bias)
        self.k_proj.weight = nn.Parameter(new_k_weight)
        self.k_proj.bias = nn.Parameter(new_k_bias)
        self.v_proj.weight = nn.Parameter(new_v_weight)
        self.v_proj.bias = nn.Parameter(new_v_bias)
        self.out_proj.weight = nn.Parameter(new_out_proj_weight)
        self.num_heads = len(reserve_head_index)
        self.embed_dim = self.head_dim * self.num_heads
        self.q_proj.out_features = self.embed_dim
        self.k_proj.out_features = self.embed_dim
        self.v_proj.out_features = self.embed_dim

    def _set_skip_embed_dim_check(self):
        self.skip_embed_dim_check = True

    def _pad_masks(self, key_padding_mask, attn_mask):
        if attn_mask is not None:
            shape = attn_mask.size()[:-1] + torch.Size([1])
            attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1)

        if key_padding_mask is not None:
            shape = key_padding_mask.size()[:-1] + torch.Size([1])
            key_padding_mask = torch.cat([key_padding_mask, key_padding_mask.new_zeros(shape)], dim=-1)

        return key_padding_mask, attn_mask

    def _add_bias(self, k, v, key_padding_mask, attn_mask, bsz):
        assert self.bias_k is not None or self.bias_v is not None
        key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        return torch.cat([k, self.bias_k.repeat(1, bsz, 1)]), torch.cat([v, self.bias_v.repeat(1, bsz, 1)]), key_padding_mask, attn_mask

    def _append_zero_attn(self, k, v, key_padding_mask, attn_mask):
        zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:]
        key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        return torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=-2), torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=-2), key_padding_mask, attn_mask

    def forward(self, query, key, value, key_padding_mask = None, incremental_state = None, need_weights = True, static_kv = False, attn_mask = None, before_softmax = False, need_head_weights = False):
        if need_head_weights: need_weights = True
        is_tpu = query.device.type == "xla"
        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len

        if not self.skip_embed_dim_check: assert (embed_dim == self.embed_dim)
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert value is not None
                assert src_len, key_bsz == value.shape[:2]

        if (not self.onnx_trace and not is_tpu and incremental_state is None and not static_kv and not torch.jit.is_scripting() and not self.skip_embed_dim_check):
            assert key is not None and value is not None
            return F.multi_head_attention_forward(query, key, value, self.embed_dim, self.num_heads, torch.empty([0]), torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), self.bias_k, self.bias_v, self.add_zero_attn, self.dropout_module.p, self.out_proj.weight, self.out_proj.bias, self.training or self.dropout_module.apply_during_inference, key_padding_mask.bool() if key_padding_mask is not None else None, need_weights, attn_mask, use_separate_proj_weight=True, q_proj_weight=self.q_proj.weight, k_proj_weight=self.k_proj.weight, v_proj_weight=self.v_proj.weight)

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else: saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                if self.beam_size > 1 and bsz == key.size(1):
                    key = key.view(key.size(0), -1, self.beam_size, key.size(2))[:, :, 0, :]
                    if key_padding_mask is not None: key_padding_mask = key_padding_mask.view(-1, self.beam_size, key_padding_mask.size(1))[:, 0, :]
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)

        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k, v, attn_mask, key_padding_mask = self._add_bias(k, v, attn_mask, key_padding_mask, bsz)

        q = (q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1))
        kv_bsz = bsz

        if k is not None:
            kv_bsz = k.size(1)
            k = (k.contiguous().view(-1, kv_bsz * self.num_heads, self.head_dim).transpose(0, 1))

        if v is not None: v = (v.contiguous().view(-1, kv_bsz * self.num_heads, self.head_dim).transpose(0, 1))

        if saved_state is not None:
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None

                kv_bsz = _prev_key.size(0)
                prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim)

                if static_kv: k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)

            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None or kv_bsz == _prev_value.size(0)
                prev_value = _prev_value.view(kv_bsz * self.num_heads, -1, self.head_dim)

                if static_kv: v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)

            prev_key_padding_mask = None
            if "prev_key_padding_mask" in saved_state: prev_key_padding_mask = saved_state["prev_key_padding_mask"]

            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(key_padding_mask=key_padding_mask, prev_key_padding_mask=prev_key_padding_mask, batch_size=kv_bsz, src_len=k.size(1), static_kv=static_kv)

            saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(kv_bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask

            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)

        assert k is not None
        assert k.size(1) == src_len

        if key_padding_mask is not None and key_padding_mask.dim() == 0: key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == kv_bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k, v, key_padding_mask, attn_mask = self._append_zero_attn(k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask)

        if self.encoder_decoder_attention and bsz != kv_bsz:
            attn_weights = torch.einsum("bxhtd,bhsd->bxhts", q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]), k.view((kv_bsz, self.num_heads) + k.size()[1:]))
            attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:])
        else: attn_weights = torch.bmm(q, k.transpose(1, 2))

        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace: attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(kv_bsz, -1, self.num_heads, tgt_len, src_len).masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(torch.bool), float("-inf")) if not is_tpu else attn_weights.transpose(0, 2).masked_fill(key_padding_mask, float("-inf")).transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax: return attn_weights, v

        attn_weights_float = softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = None

        if self.encoder_decoder_attention and bsz != kv_bsz:
            attn = torch.einsum("bxhts,bhsd->bxhtd", attn_probs.view((kv_bsz, -1, self.num_heads) + attn_probs.size()[1:]), v.view((kv_bsz, self.num_heads) + v.size()[1:]))
            attn = attn.reshape((-1,) + attn.size()[-2:])
        else: attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]

        if self.onnx_trace and attn.size(1) == 1: attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim)
        else: attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)

        attn = self.out_proj(attn)
        attn_weights = None

        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
            if not need_head_weights: attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights

    @staticmethod
    def _append_prev_key_padding_mask(key_padding_mask, prev_key_padding_mask, batch_size, src_len, static_kv):
        if prev_key_padding_mask is not None and static_kv: new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None: new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
        elif prev_key_padding_mask is not None:
            if src_len > prev_key_padding_mask.size(1):
                filler = torch.zeros((batch_size, src_len - prev_key_padding_mask.size(1)), device=prev_key_padding_mask.device)
                new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
            else: new_key_padding_mask = prev_key_padding_mask.float()
        elif key_padding_mask is not None:
            if src_len > key_padding_mask.size(1):
                filler = torch.zeros((batch_size, src_len - key_padding_mask.size(1)), device=key_padding_mask.device)
                new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
            else: new_key_padding_mask = key_padding_mask.float()
        else: new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    @torch.jit.export
    def reorder_incremental_state(self, incremental_state, new_order):
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer_k = input_buffer[k]
                if input_buffer_k is not None:
                    if self.encoder_decoder_attention:
                        if input_buffer_k.size(0) * self.beam_size == new_order.size(0): return incremental_state
                        elif self.beam_size > 1: input_buffer[k] = input_buffer_k.index_select(0, new_order.reshape(-1, self.beam_size)[:, 0] // self.beam_size)
                        else: input_buffer[k] = input_buffer_k.index_select(0, new_order)
                    else: input_buffer[k] = input_buffer_k.index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
        return incremental_state

    def set_beam_size(self, beam_size):
        self.beam_size = beam_size

    def _get_input_buffer(self, incremental_state):
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None: return result
        else: return {}

    def _set_input_buffer(self, incremental_state, buffer):
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
        return attn_weights

    def upgrade_state_dict_named(self, state_dict, name):
        prefix = name + "." if name != "" else ""
        items_to_add = {}
        keys_to_remove = []
        for k in state_dict.keys():
            if k.endswith(prefix + "in_proj_weight"):
                dim = int(state_dict[k].shape[0] / 3)
                items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
                items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
                items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
                keys_to_remove.append(k)
                k_bias = prefix + "in_proj_bias"
                if k_bias in state_dict.keys():
                    dim = int(state_dict[k].shape[0] / 3)
                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
                    items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim]
                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
                    keys_to_remove.append(prefix + "in_proj_bias")

        for k in keys_to_remove:
            del state_dict[k]

        for key, value in items_to_add.items():
            state_dict[key] = value

def init_bert_params(module):
    def normal_(data):
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None: module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
|
578 |
-
normal_(module.weight.data)
|
579 |
-
if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_()
|
580 |
-
if isinstance(module, MultiheadAttention):
|
581 |
-
normal_(module.q_proj.weight.data)
|
582 |
-
normal_(module.k_proj.weight.data)
|
583 |
-
normal_(module.v_proj.weight.data)
|
584 |
-
|
585 |
-
def make_conv_pos(e, k, g):
|
586 |
-
pos_conv = nn.Conv1d(e, e, kernel_size=k, padding=k // 2, groups=g)
|
587 |
-
dropout = 0
|
588 |
-
|
589 |
-
nn.init.normal_(pos_conv.weight, mean=0, std=math.sqrt((4 * (1.0 - dropout)) / (k * e)))
|
590 |
-
nn.init.constant_(pos_conv.bias, 0)
|
591 |
-
|
592 |
-
return nn.Sequential(nn.utils.parametrizations.weight_norm(pos_conv, name="weight", dim=2), SamePad(k), nn.GELU())
|
593 |
-
|
594 |
-
def is_xla_tensor(tensor):
|
595 |
-
return torch.is_tensor(tensor) and tensor.device.type == "xla"
|
596 |
-
|
597 |
-
def index_put(tensor, indices, value):
|
598 |
-
if is_xla_tensor(tensor):
|
599 |
-
for _ in range(indices.dim(), tensor.dim()):
|
600 |
-
indices = indices.unsqueeze(-1)
|
601 |
-
|
602 |
-
if indices.size(-1) < tensor.size(-1): indices = indices.expand_as(tensor)
|
603 |
-
tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
|
604 |
-
else: tensor[indices] = value
|
605 |
-
|
606 |
-
return tensor
|
607 |
-
|
608 |
-
def pad_to_multiple(x, multiple, dim=-1, value=0):
|
609 |
-
if x is None: return None, 0
|
610 |
-
tsz = x.size(dim)
|
611 |
-
m = tsz / multiple
|
612 |
-
remainder = math.ceil(m) * multiple - tsz
|
613 |
-
if m.is_integer(): return x, 0
|
614 |
-
return F.pad(x, (*((0,) * (-1 - dim) * 2), 0, remainder), value=value), remainder
|
615 |
-
|
616 |
-
def compute_mask_indices(shape, padding_mask, mask_prob, mask_length, mask_type = "static", mask_other = 0.0, min_masks = 0, no_overlap = False, min_space = 0, require_same_masks = True, mask_dropout = 0.0, add_masks = False, seed = None, epoch = None, indices = None, idc_select_ver = 1, num_mask_ver = 2):
|
617 |
-
bsz, all_sz = shape
|
618 |
-
mask = np.full((bsz, all_sz), False)
|
619 |
-
|
620 |
-
if num_mask_ver == 1: all_num_mask = max(min_masks, int(mask_prob * all_sz / float(mask_length) + np.random.rand()))
|
621 |
-
mask_idcs = []
|
622 |
-
|
623 |
-
for i in range(bsz):
|
624 |
-
seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6) if seed is not None and epoch is not None and indices is not None else None
|
625 |
-
rng = np.random.default_rng(seed_i)
|
626 |
-
|
627 |
-
if padding_mask is not None:
|
628 |
-
sz = all_sz - padding_mask[i].long().sum().item()
|
629 |
-
assert sz >= 0, sz
|
630 |
-
else: sz = all_sz
|
631 |
-
|
632 |
-
if num_mask_ver == 1: num_mask = max(min_masks, int(mask_prob * sz / float(mask_length) + np.random.rand())) if padding_mask is not None else all_num_mask
|
633 |
-
elif num_mask_ver == 2: num_mask = max(min_masks, int(mask_prob * sz / float(mask_length) + rng.random()))
|
634 |
-
else: raise ValueError
|
635 |
-
|
636 |
-
if mask_type == "static": lengths = np.full(num_mask, mask_length)
|
637 |
-
elif mask_type == "uniform": lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask)
|
638 |
-
elif mask_type == "normal": lengths = [max(1, int(round(x))) for x in rng.normal(mask_length, mask_other, size=num_mask)]
|
639 |
-
elif mask_type == "poisson": lengths = [int(round(x)) for x in rng.poisson(mask_length, size=num_mask)]
|
640 |
-
else: raise Exception
|
641 |
-
|
642 |
-
if sum(lengths) == 0:
|
643 |
-
if mask_type == "static": raise ValueError
|
644 |
-
else: lengths = [min(mask_length, sz - 1)]
|
645 |
-
|
646 |
-
if no_overlap:
|
647 |
-
mask_idc = []
|
648 |
-
|
649 |
-
def arrange(s, e, length, keep_length):
|
650 |
-
span_start = rng.randint(s, e - length)
|
651 |
-
mask_idc.extend(span_start + i for i in range(length))
|
652 |
-
new_parts = []
|
653 |
-
|
654 |
-
if span_start - s - min_space >= keep_length: new_parts.append((s, span_start - min_space + 1))
|
655 |
-
if e - span_start - length - min_space > keep_length: new_parts.append((span_start + length + min_space, e))
|
656 |
-
|
657 |
-
return new_parts
|
658 |
-
|
659 |
-
parts = [(0, sz)]
|
660 |
-
min_length = min(lengths)
|
661 |
-
|
662 |
-
for length in sorted(lengths, reverse=True):
|
663 |
-
lens = np.fromiter((e - s if e - s >= length + min_space else 0 for s, e in parts), np.int32)
|
664 |
-
l_sum = np.sum(lens)
|
665 |
-
if l_sum == 0: break
|
666 |
-
s, e = parts.pop(rng.choice(len(parts), p=lens / np.sum(lens)))
|
667 |
-
parts.extend(arrange(s, e, length, min_length))
|
668 |
-
mask_idc = np.asarray(mask_idc)
|
669 |
-
else:
|
670 |
-
if idc_select_ver == 1:
|
671 |
-
min_len = min(lengths)
|
672 |
-
if sz - min_len <= num_mask: min_len = sz - num_mask - 1
|
673 |
-
mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
|
674 |
-
elif idc_select_ver == 2: mask_idc = rng.choice(sz, num_mask, replace=False)
|
675 |
-
else: raise ValueError
|
676 |
-
|
677 |
-
mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
|
678 |
-
|
679 |
-
mask_idc = np.unique(mask_idc[mask_idc < sz])
|
680 |
-
if len(mask_idc) >= sz: raise ValueError
|
681 |
-
mask_idcs.append(mask_idc)
|
682 |
-
|
683 |
-
target_len = None
|
684 |
-
if require_same_masks: target_len = max([len(m) for m in mask_idcs]) if add_masks else min([len(m) for m in mask_idcs])
|
685 |
-
|
686 |
-
for i, mask_idc in enumerate(mask_idcs):
|
687 |
-
if target_len is not None and len(mask_idc) > target_len: mask_idc = rng.choice(mask_idc, target_len, replace=False)
|
688 |
-
mask[i, mask_idc] = True
|
689 |
-
|
690 |
-
if target_len is not None and len(mask_idc) < target_len:
|
691 |
-
to_mask = rng.choice(np.flatnonzero(~mask[i]), target_len - len(mask_idc), replace=False)
|
692 |
-
mask[i, to_mask] = True
|
693 |
-
|
694 |
-
if mask_dropout > 0:
|
695 |
-
masked = np.flatnonzero(mask[i])
|
696 |
-
mask[i, rng.choice(masked, np.rint(len(masked) * mask_dropout).astype(int), replace=False)] = False
|
697 |
-
|
698 |
-
return mask
|
699 |
-
|
700 |
-
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
|
701 |
-
return nn.LayerNorm(normalized_shape, eps, elementwise_affine)
|
702 |
-
|
703 |
-
def prune_state_dict(state_dict, model_cfg):
|
704 |
-
arch = None
|
705 |
-
if model_cfg is not None: arch = (model_cfg._name if isinstance(model_cfg, DictConfig) else getattr(model_cfg, "arch", None))
|
706 |
-
|
707 |
-
if not model_cfg or arch is None or arch == "ptt_transformer": return state_dict
|
708 |
-
|
709 |
-
encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
|
710 |
-
decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
|
711 |
-
|
712 |
-
if not encoder_layers_to_keep and not decoder_layers_to_keep: return state_dict
|
713 |
-
|
714 |
-
def create_pruning_pass(layers_to_keep, layer_name):
|
715 |
-
keep_layers = sorted(int(layer_string) for layer_string in layers_to_keep.split(","))
|
716 |
-
mapping_dict = {}
|
717 |
-
for i in range(len(keep_layers)):
|
718 |
-
mapping_dict[str(keep_layers[i])] = str(i)
|
719 |
-
|
720 |
-
return {"substitution_regex": re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name)), "mapping_dict": mapping_dict}
|
721 |
-
|
722 |
-
pruning_passes = []
|
723 |
-
new_state_dict = {}
|
724 |
-
|
725 |
-
if encoder_layers_to_keep: pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
|
726 |
-
if decoder_layers_to_keep: pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))
|
727 |
-
|
728 |
-
for layer_name in state_dict.keys():
|
729 |
-
match = re.search(r"\.layers\.(\d+)\.", layer_name)
|
730 |
-
if not match:
|
731 |
-
new_state_dict[layer_name] = state_dict[layer_name]
|
732 |
-
continue
|
733 |
-
|
734 |
-
original_layer_number = match.group(1)
|
735 |
-
for pruning_pass in pruning_passes:
|
736 |
-
if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass["substitution_regex"].search(layer_name):
|
737 |
-
substitution_match = pruning_pass["substitution_regex"].search(layer_name)
|
738 |
-
new_state_dict[(layer_name[: substitution_match.start(1)] + pruning_pass["mapping_dict"][original_layer_number] + layer_name[substitution_match.end(1) :])] = state_dict[layer_name]
|
739 |
-
|
740 |
-
with open_dict(model_cfg) if isinstance(model_cfg, DictConfig) else contextlib.ExitStack():
|
741 |
-
if hasattr(model_cfg, "encoder_layers_to_keep"): model_cfg.encoder_layers_to_keep = None
|
742 |
-
if hasattr(model_cfg, "decoder_layers_to_keep"): model_cfg.decoder_layers_to_keep = None
|
743 |
-
|
744 |
-
return new_state_dict
|
745 |
-
|
746 |
-
def relu_squared(x):
|
747 |
-
return F.relu(x).pow(2)
|
748 |
-
|
749 |
-
def get_activation_fn(activation):
|
750 |
-
def gelu(x):
|
751 |
-
return nn.functional.gelu(x.float()).type_as(x)
|
752 |
-
|
753 |
-
def gelu_accurate(x):
|
754 |
-
if not hasattr(gelu_accurate, "_a"):
|
755 |
-
gelu_accurate._a = math.sqrt(2 / math.pi)
|
756 |
-
return (0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))))
|
757 |
-
|
758 |
-
if activation == "relu": return F.relu
|
759 |
-
elif activation == "relu_squared": return relu_squared
|
760 |
-
elif activation == "gelu": return gelu
|
761 |
-
elif activation == "gelu_fast": return gelu_accurate
|
762 |
-
elif activation == "gelu_accurate": return gelu_accurate
|
763 |
-
elif activation == "tanh": return torch.tanh
|
764 |
-
elif activation == "linear": return lambda x: x
|
765 |
-
elif activation == "swish": return nn.SiLU
|
766 |
-
else: raise RuntimeError
|
767 |
-
|
768 |
-
class SamePad(nn.Module):
|
769 |
-
def __init__(self, kernel_size, causal=False):
|
770 |
-
super().__init__()
|
771 |
-
if causal: self.remove = kernel_size - 1
|
772 |
-
else: self.remove = 1 if kernel_size % 2 == 0 else 0
|
773 |
-
|
774 |
-
def forward(self, x):
|
775 |
-
if self.remove > 0: x = x[:, :, : -self.remove]
|
776 |
-
return x
|
777 |
-
|
778 |
-
class TransformerSentenceEncoderLayer(nn.Module):
|
779 |
-
def __init__(self, embedding_dim = 768, ffn_embedding_dim = 3072, num_attention_heads = 8, dropout = 0.1, attention_dropout = 0.1, activation_dropout = 0.1, activation_fn = "relu", layer_norm_first = False):
|
780 |
-
super().__init__()
|
781 |
-
self.embedding_dim = embedding_dim
|
782 |
-
self.dropout = dropout
|
783 |
-
self.activation_dropout = activation_dropout
|
784 |
-
self.activation_fn = get_activation_fn(activation_fn)
|
785 |
-
self.self_attn = MultiheadAttention(self.embedding_dim, num_attention_heads, dropout=attention_dropout, self_attention=True)
|
786 |
-
self.dropout1 = nn.Dropout(dropout)
|
787 |
-
self.dropout2 = nn.Dropout(self.activation_dropout)
|
788 |
-
self.dropout3 = nn.Dropout(dropout)
|
789 |
-
self.layer_norm_first = layer_norm_first
|
790 |
-
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
|
791 |
-
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
|
792 |
-
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
|
793 |
-
self.final_layer_norm = LayerNorm(self.embedding_dim)
|
794 |
-
|
795 |
-
def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None):
|
796 |
-
residual = x
|
797 |
-
|
798 |
-
if self.layer_norm_first:
|
799 |
-
x = self.self_attn_layer_norm(x)
|
800 |
-
x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, attn_mask=self_attn_mask, need_weights=False)
|
801 |
-
x = residual + self.dropout1(x)
|
802 |
-
residual = x
|
803 |
-
x = self.fc2(self.dropout2(self.activation_fn(self.fc1(self.final_layer_norm(x)))))
|
804 |
-
layer_result = x
|
805 |
-
x = residual + self.dropout3(x)
|
806 |
-
else:
|
807 |
-
x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, need_weights=False)
|
808 |
-
x = self.self_attn_layer_norm(residual + self.dropout1(x))
|
809 |
-
residual = x
|
810 |
-
x = self.fc2(self.dropout2(self.activation_fn(self.fc1(x))))
|
811 |
-
layer_result = x
|
812 |
-
x = self.final_layer_norm(residual + self.dropout3(x))
|
813 |
-
|
814 |
-
return x, (attn, layer_result)
|
815 |
-
|
816 |
-
class AdapterFast(nn.Module):
|
817 |
-
def __init__(self, adapter_num, input_dim, hidden_dim, act_fn):
|
818 |
-
super().__init__()
|
819 |
-
self.adapter_num = adapter_num
|
820 |
-
self.input_dim = input_dim
|
821 |
-
self.hidden_dim = hidden_dim
|
822 |
-
self.W_a = nn.Parameter(torch.empty(adapter_num, hidden_dim, input_dim))
|
823 |
-
self.W_b = nn.Parameter(torch.empty(adapter_num, input_dim, hidden_dim))
|
824 |
-
self.b_a = nn.Parameter(torch.empty(adapter_num, hidden_dim))
|
825 |
-
self.b_b = nn.Parameter(torch.empty(adapter_num, input_dim))
|
826 |
-
self.ln_W = nn.Parameter(torch.empty(adapter_num, input_dim))
|
827 |
-
self.ln_b = nn.Parameter(torch.empty(adapter_num, input_dim))
|
828 |
-
self.act_fn = nn.Identity()
|
829 |
-
if act_fn == "relu": self.act_fn = nn.ReLU()
|
830 |
-
elif act_fn == "gelu": self.act_fn = nn.GELU()
|
831 |
-
elif act_fn == "selu": self.act_fn = nn.SELU()
|
832 |
-
else: raise ValueError
|
833 |
-
|
834 |
-
self.input_dim = input_dim
|
835 |
-
self.reset_parameters()
|
836 |
-
|
837 |
-
def reset_parameters(self):
|
838 |
-
for ii in range(self.adapter_num):
|
839 |
-
nn.init.kaiming_uniform_(self.W_a[ii], a=math.sqrt(5))
|
840 |
-
nn.init.kaiming_uniform_(self.W_b[ii], a=math.sqrt(5))
|
841 |
-
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_a[ii])
|
842 |
-
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
843 |
-
nn.init.uniform_(self.b_a[ii], -bound, bound)
|
844 |
-
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_b[ii])
|
845 |
-
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
846 |
-
nn.init.uniform_(self.b_b[ii], -bound, bound)
|
847 |
-
|
848 |
-
nn.init.ones_(self.ln_W)
|
849 |
-
nn.init.zeros_(self.ln_b)
|
850 |
-
|
851 |
-
def forward(self, x, adapter_id):
|
852 |
-
ii = adapter_id
|
853 |
-
return F.linear(self.act_fn(F.linear(F.layer_norm(x, (self.input_dim, ), self.ln_W[ii], self.ln_b[ii]), self.W_a[ii], self.b_a[ii])), self.W_b[ii], self.b_b[ii])
|
854 |
-
|
855 |
-
def extra_repr(self):
|
856 |
-
return ('adapter={}, input_dim={}, hidden_dim={}'.format(self.adapter_num, self.input_dim, self.hidden_dim))
|
857 |
-
|
858 |
-
class FeedForwardModule(nn.Module):
|
859 |
-
def __init__(self, input_feat, hidden_units, dropout1, dropout2, activation_fn="swish", bias=True):
|
860 |
-
super(FeedForwardModule, self).__init__()
|
861 |
-
self.layer_norm = LayerNorm(input_feat)
|
862 |
-
self.w_1 = nn.Linear(input_feat, hidden_units, bias=bias)
|
863 |
-
self.w_2 = nn.Linear(hidden_units, input_feat, bias=bias)
|
864 |
-
self.dropout1 = nn.Dropout(dropout1)
|
865 |
-
self.dropout2 = nn.Dropout(dropout2)
|
866 |
-
self.activation = get_activation_fn(activation_fn)(hidden_units)
|
867 |
-
|
868 |
-
def forward(self, x):
|
869 |
-
return self.dropout2(self.w_2(self.dropout1(self.activation(self.w_1(self.layer_norm(x))))))
|
870 |
-
|
871 |
-
class ConvolutionModule(nn.Module):
|
872 |
-
def __init__(self, embed_dim, channels, depthwise_kernel_size, dropout, activation_fn="swish", bias=False, export=False):
|
873 |
-
super(ConvolutionModule, self).__init__()
|
874 |
-
assert (depthwise_kernel_size - 1) % 2 == 0
|
875 |
-
self.layer_norm = LayerNorm(embed_dim, export=export)
|
876 |
-
self.pointwise_conv1 = nn.Conv1d(embed_dim, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias)
|
877 |
-
self.glu = nn.GLU(dim=1)
|
878 |
-
self.depthwise_conv = nn.Conv1d(channels, channels, depthwise_kernel_size, stride=1, padding=(depthwise_kernel_size - 1) // 2, groups=channels, bias=bias)
|
879 |
-
self.batch_norm = nn.BatchNorm1d(channels)
|
880 |
-
self.activation = get_activation_fn(activation_fn)(channels)
|
881 |
-
self.pointwise_conv2 = nn.Conv1d(channels, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias)
|
882 |
-
self.dropout = nn.Dropout(dropout)
|
883 |
-
|
884 |
-
def forward(self, x):
|
885 |
-
return self.dropout(self.pointwise_conv2(self.activation(self.batch_norm(self.depthwise_conv(self.glu(self.pointwise_conv1(self.layer_norm(x).transpose(1, 2)))))))).transpose(1, 2)
|
886 |
-
|
887 |
-
def rotate_half(x):
|
888 |
-
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
|
889 |
-
return torch.cat((-x2, x1), dim=x1.ndim - 1)
|
890 |
-
|
891 |
-
def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
|
892 |
-
cos, sin = (cos[offset : q.shape[0] + offset, ...], sin[offset : q.shape[0] + offset, ...])
|
893 |
-
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
|
894 |
-
|
895 |
-
class RotaryPositionalEmbedding(nn.Module):
|
896 |
-
def __init__(self, dim, base=10000, precision=torch.half):
|
897 |
-
super().__init__()
|
898 |
-
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
899 |
-
self.register_buffer("inv_freq", inv_freq)
|
900 |
-
self.seq_len_cached = 0
|
901 |
-
self.cos_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
|
902 |
-
self.sin_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
|
903 |
-
self.precision = precision
|
904 |
-
|
905 |
-
def forward(self, x, seq_len = 0):
|
906 |
-
if seq_len > self.seq_len_cached:
|
907 |
-
self.seq_len_cached = seq_len
|
908 |
-
freqs = torch.einsum("i,j->ij", torch.arange(seq_len, device=x.device).type_as(self.inv_freq), self.inv_freq)
|
909 |
-
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
910 |
-
self.cos_cached = emb.cos().view(emb.size(0), 1, 1, emb.size(1))
|
911 |
-
self.sin_cached = emb.sin().view(emb.size(0), 1, 1, emb.size(1))
|
912 |
-
return self.cos_cached, self.sin_cached
|
913 |
-
|
914 |
-
class ESPNETMultiHeadedAttention(nn.Module):
|
915 |
-
def __init__(self, n_feat, n_head, dropout):
|
916 |
-
super(ESPNETMultiHeadedAttention, self).__init__()
|
917 |
-
assert n_feat % n_head == 0
|
918 |
-
self.d_k = n_feat // n_head
|
919 |
-
self.h = n_head
|
920 |
-
self.linear_q = nn.Linear(n_feat, n_feat)
|
921 |
-
self.linear_k = nn.Linear(n_feat, n_feat)
|
922 |
-
self.linear_v = nn.Linear(n_feat, n_feat)
|
923 |
-
self.linear_out = nn.Linear(n_feat, n_feat)
|
924 |
-
self.attn = None
|
925 |
-
self.dropout = nn.Dropout(p=dropout)
|
926 |
-
|
927 |
-
def forward_qkv(self, query, key, value, **kwargs):
|
928 |
-
n_batch = query.size(0)
|
929 |
-
return self.linear_q(query).view(n_batch, -1, self.h, self.d_k).transpose(1, 2), self.linear_k(key).view(n_batch, -1, self.h, self.d_k).transpose(1, 2), self.linear_v(value).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
|
930 |
-
|
931 |
-
def forward_attention(self, value, scores, mask):
|
932 |
-
n_batch = value.size(0)
|
933 |
-
|
934 |
-
if mask is not None:
|
935 |
-
scores = scores.masked_fill(mask.unsqueeze(1).unsqueeze(2).to(bool), float("-inf"))
|
936 |
-
self.attn = torch.softmax(scores, dim=-1)
|
937 |
-
else: self.attn = torch.softmax(scores, dim=-1)
|
938 |
-
|
939 |
-
return self.linear_out((torch.matmul(self.dropout(self.attn), value).transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)))
|
940 |
-
|
941 |
-
def forward(self, query, key, value, key_padding_mask=None, **kwargs):
|
942 |
-
q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
|
943 |
-
return self.forward_attention(v, torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k), key_padding_mask).transpose(0, 1), None
|
944 |
-
|
945 |
-
class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
|
946 |
-
def __init__(self, n_feat, n_head, dropout, zero_triu=False):
|
947 |
-
super().__init__(n_feat, n_head, dropout)
|
948 |
-
self.zero_triu = zero_triu
|
949 |
-
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
950 |
-
self.pos_bias_u = nn.Parameter(torch.zeros(self.h, self.d_k))
|
951 |
-
self.pos_bias_v = nn.Parameter(torch.zeros(self.h, self.d_k))
|
952 |
-
nn.init.xavier_uniform_(self.pos_bias_u)
|
953 |
-
nn.init.xavier_uniform_(self.pos_bias_v)
|
954 |
-
|
955 |
-
def rel_shift(self, x):
|
956 |
-
x = torch.cat([torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype), x], dim=-1).view(*x.size()[:2], x.size(3) + 1, x.size(2))[:, :, 1:].view_as(x)[:, :, :, : x.size(-1) // 2 + 1]
|
957 |
-
if self.zero_triu: x = x * torch.tril(torch.ones((x.size(2), x.size(3)), device=x.device), x.size(3) - x.size(2))[None, None, :, :]
|
958 |
-
return x
|
959 |
-
|
960 |
-
def forward(self, query, key, value, pos_emb, key_padding_mask=None, **kwargs):
|
961 |
-
pos_emb = pos_emb.transpose(0, 1)
|
962 |
-
q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
|
963 |
-
q = q.transpose(1, 2)
|
964 |
-
|
965 |
-
return self.forward_attention(v, (torch.matmul((q + self.pos_bias_u).transpose(1, 2), k.transpose(-2, -1)) + self.rel_shift(torch.matmul((q + self.pos_bias_v).transpose(1, 2), self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.h, self.d_k).transpose(1, 2).transpose(-2, -1)))) / math.sqrt(self.d_k), key_padding_mask).transpose(0, 1), None
|
966 |
-
|
967 |
-
class RotaryPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
|
968 |
-
def __init__(self, n_feat, n_head, dropout, precision, rotary_emd_base=10000):
|
969 |
-
super().__init__(n_feat, n_head, dropout)
|
970 |
-
precision = torch.float
|
971 |
-
self.rotary_ndims = self.d_k
|
972 |
-
if precision == "fp16": precision = torch.half
|
973 |
-
self.rotary_emb = RotaryPositionalEmbedding(self.rotary_ndims, base=rotary_emd_base, precision=precision)
|
974 |
-
|
975 |
-
def forward(self, query, key, value, key_padding_mask=None, **kwargs):
|
976 |
-
T, B, C = value.size()
|
977 |
-
query = query.view(T, B, self.h, self.d_k)
|
978 |
-
key = key.view(T, B, self.h, self.d_k)
|
979 |
-
value = value.view(T, B, self.h, self.d_k)
|
980 |
-
|
981 |
-
cos, sin = self.rotary_emb(value, seq_len=T)
|
982 |
-
query, key = apply_rotary_pos_emb(query, key, cos, sin, offset=0)
|
983 |
-
|
984 |
-
query = query.view(T, B, self.h * self.d_k)
|
985 |
-
key = key.view(T, B, self.h * self.d_k)
|
986 |
-
value = value.view(T, B, self.h * self.d_k)
|
987 |
-
|
988 |
-
q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
|
989 |
-
return self.forward_attention(v, torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k), key_padding_mask).transpose(0, 1), None
|
990 |
-
|
991 |
-
class ConformerEncoderLayer(nn.Module):
|
992 |
-
def __init__(self, embed_dim, ffn_embed_dim, attention_heads, dropout, use_fp16, depthwise_conv_kernel_size=31, activation_fn="swish", attn_type=None, pos_enc_type="abs"):
|
993 |
-
self.pos_enc_type = pos_enc_type
|
994 |
-
super(ConformerEncoderLayer, self).__init__()
|
995 |
-
self.ffn1 = FeedForwardModule(embed_dim, ffn_embed_dim, dropout, dropout)
|
996 |
-
self.self_attn_layer_norm = LayerNorm(embed_dim, export=False)
|
997 |
-
self.self_attn_dropout = nn.Dropout(dropout)
|
998 |
-
|
999 |
-
if attn_type == "espnet":
|
1000 |
-
if self.pos_enc_type == "rel_pos": self.self_attn = RelPositionMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout)
|
1001 |
-
elif self.pos_enc_type == "rope": self.self_attn = RotaryPositionMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout, precision=use_fp16)
|
1002 |
-
elif self.pos_enc_type == "abs": self.self_attn = ESPNETMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout)
|
1003 |
-
else: raise Exception
|
1004 |
-
else: self.self_attn = MultiheadAttention(embed_dim, attention_heads, dropout=dropout)
|
1005 |
-
|
1006 |
-
self.conv_module = ConvolutionModule(embed_dim=embed_dim, channels=embed_dim, depthwise_kernel_size=depthwise_conv_kernel_size, dropout=dropout, activation_fn=activation_fn)
|
1007 |
-
self.ffn2 = FeedForwardModule(embed_dim, ffn_embed_dim, dropout, dropout, activation_fn=activation_fn)
|
1008 |
-
self.final_layer_norm = LayerNorm(embed_dim, export=False)
|
1009 |
-
|
1010 |
-
def forward(self, x, encoder_padding_mask, position_emb = None):
|
1011 |
-
residual = x
|
1012 |
-
x = self.ffn1(x) * 0.5 + residual
|
1013 |
-
residual = x
|
1014 |
-
x = self.self_attn_layer_norm(x)
|
1015 |
-
|
1016 |
-
if self.pos_enc_type == "rel_pos": x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, pos_emb=position_emb, need_weights=False)
|
1017 |
-
else: x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, need_weights=False)
|
1018 |
-
|
1019 |
-
x = self.self_attn_dropout(x)
|
1020 |
-
x = x + residual
|
1021 |
-
residual = x
|
1022 |
-
x = residual + self.conv_module(x.transpose(0, 1)).transpose(0, 1)
|
1023 |
-
residual = x
|
1024 |
-
x = self.ffn2(x)
|
1025 |
-
layer_result = x
|
1026 |
-
x = self.final_layer_norm(x * 0.5 + residual)
|
1027 |
-
|
1028 |
-
return x, (attn, layer_result)
|
1029 |
-
|
1030 |
-
class ConformerWav2Vec2EncoderLayer(ConformerEncoderLayer):
|
1031 |
-
def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None, position_emb=None):
|
1032 |
-
return super().forward(x, self_attn_padding_mask, position_emb)
|
1033 |
-
|
1034 |
-
class TransformerSentenceEncoderWithAdapterLayer(TransformerSentenceEncoderLayer):
|
1035 |
-
def __init__(self, embedding_dim = 768, ffn_embedding_dim = 3072, num_attention_heads = 8, dropout = 0.1, attention_dropout = 0.1, activation_dropout = 0.1, activation_fn = "relu", layer_norm_first = False, adapter_num=201, adapter_dim=64, adapter_act_fn="relu"):
|
1036 |
-
super().__init__(embedding_dim=embedding_dim, ffn_embedding_dim=ffn_embedding_dim, num_attention_heads=num_attention_heads, dropout=dropout, attention_dropout=attention_dropout, activation_dropout=activation_dropout, activation_fn=activation_fn, layer_norm_first=layer_norm_first)
|
1037 |
-
self.adapter_num = adapter_num
|
1038 |
-
self.adapter_dim = adapter_dim
|
1039 |
-
self.adapter_layer = AdapterFast(adapter_num, self.embedding_dim, self.adapter_dim, adapter_act_fn)
|
1040 |
-
|
1041 |
-
def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None, corpus_key=None):
|
1042 |
-
|
1043 |
-
x, (attn, layer_result) = super().forward(x=x, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, need_weights=need_weights, att_args=att_args)
|
1044 |
-
assert corpus_key is not None
|
1045 |
-
assert len(set(corpus_key)) == 1
|
1046 |
-
|
1047 |
-
return x + self.adapter_layer(x, corpus_key[0]), (attn, layer_result)
|
1048 |
-
|
1049 |
-
class TransposeLast(nn.Module):
|
1050 |
-
def __init__(self, deconstruct_idx=None, tranpose_dim=-2):
|
1051 |
-
super().__init__()
|
1052 |
-
self.deconstruct_idx = deconstruct_idx
|
1053 |
-
self.tranpose_dim = tranpose_dim
|
1054 |
-
|
1055 |
-
def forward(self, x):
|
1056 |
-
if self.deconstruct_idx is not None: x = x[self.deconstruct_idx]
|
1057 |
-
return x.transpose(self.tranpose_dim, -1)
|
1058 |
-
|
1059 |
-
class TransformerEncoder(nn.Module):
|
1060 |
-
def build_encoder_layer(self, args, **kwargs):
|
1061 |
-
if args.layer_type == "transformer": layer = TransformerSentenceEncoderLayer(embedding_dim=self.embedding_dim, ffn_embedding_dim=args.encoder_ffn_embed_dim, num_attention_heads=args.encoder_attention_heads, dropout=self.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, activation_fn=args.activation_fn, layer_norm_first=args.layer_norm_first)
|
1062 |
-
elif args.layer_type == "conformer": layer = ConformerWav2Vec2EncoderLayer(embed_dim=self.embedding_dim, ffn_embed_dim=args.encoder_ffn_embed_dim, attention_heads=args.encoder_attention_heads, dropout=args.dropout, depthwise_conv_kernel_size=args.depthwise_conv_kernel_size, activation_fn="swish", attn_type=args.attn_type, use_fp16=args.fp16, pos_enc_type="abs")
|
1063 |
-
elif args.layer_type == "trf_adp":
|
1064 |
-
use_adp = False
|
1065 |
-
if args.adp_trf_idx == "all": use_adp = True
|
1066 |
-
else:
|
1067 |
-
if kwargs.get("layer_idx", None) in list(range(*[int(g) for g in args.adp_trf_idx.split(":")])): use_adp = True
|
1068 |
-
|
1069 |
-
layer = TransformerSentenceEncoderWithAdapterLayer(embedding_dim=self.embedding_dim, ffn_embedding_dim=args.encoder_ffn_embed_dim, num_attention_heads=args.encoder_attention_heads, dropout=self.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, activation_fn=args.activation_fn, layer_norm_first=args.layer_norm_first, adapter_num=args.adp_num, adapter_dim=args.adp_dim, adapter_act_fn=args.adp_act_fn) if use_adp else TransformerSentenceEncoderLayer(embedding_dim=self.embedding_dim, ffn_embedding_dim=args.encoder_ffn_embed_dim, num_attention_heads=args.encoder_attention_heads, dropout=self.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, activation_fn=args.activation_fn, layer_norm_first=args.layer_norm_first,)
|
1070 |
-
|
1071 |
-
return layer
|
1072 |
-
|
1073 |
-
def __init__(self, args):
|
1074 |
-
super().__init__()
|
1075 |
-
self.dropout = args.dropout
|
1076 |
-
self.embedding_dim = args.encoder_embed_dim
|
1077 |
-
self.required_seq_len_multiple = args.required_seq_len_multiple
|
1078 |
-
pos_conv_depth = getattr(args, "pos_conv_depth", 1)
|
1079 |
-
|
1080 |
-
if pos_conv_depth > 1:
|
1081 |
-
num_layers = args.pos_conv_depth
|
1082 |
-
k = max(3, args.conv_pos // num_layers)
|
1083 |
-
|
1084 |
-
def make_conv_block(e, k, g, l):
|
1085 |
-
return nn.Sequential(*[nn.Sequential(nn.Conv1d(e, e, kernel_size=k, padding=k // 2, groups=g), SamePad(k), TransposeLast(), LayerNorm(e, elementwise_affine=False), TransposeLast(), nn.GELU()) for _ in range(l)])
|
1086 |
-
|
1087 |
-
self.pos_conv = make_conv_block(self.embedding_dim, k, args.conv_pos_groups, num_layers)
|
1088 |
-
else: self.pos_conv = make_conv_pos(self.embedding_dim, args.conv_pos, args.conv_pos_groups)
|
1089 |
-
|
1090 |
-
self.layers = nn.ModuleList([self.build_encoder_layer(args, layer_idx=ii) for ii in range(args.encoder_layers)])
|
1091 |
-
self.layer_norm_first = args.layer_norm_first
|
1092 |
-
self.layer_norm = LayerNorm(self.embedding_dim)
|
1093 |
-
self.layerdrop = args.encoder_layerdrop
|
1094 |
-
self.apply(init_bert_params)
|
1095 |
-
|
1096 |
-
def forward(self, x, padding_mask=None, layer=None, corpus_key=None):
|
1097 |
-
x, layer_results = self.extract_features(x, padding_mask, layer, corpus_key=corpus_key)
|
1098 |
-
|
1099 |
-
if self.layer_norm_first and layer is None: x = self.layer_norm(x)
|
1100 |
-
return x, layer_results
|
1101 |
-
|
1102 |
-
def extract_features(self, x, padding_mask=None, tgt_layer=None, min_layer=0, corpus_key=None):
|
1103 |
-
if padding_mask is not None: x = index_put(x, padding_mask, 0)
|
1104 |
-
x = x + self.pos_conv(x.transpose(1, 2)).transpose(1, 2)
|
1105 |
-
|
1106 |
-
if not self.layer_norm_first: x = self.layer_norm(x)
|
1107 |
-
x, pad_length = pad_to_multiple(x, self.required_seq_len_multiple, dim=-2, value=0)
|
1108 |
-
|
1109 |
-
if pad_length > 0 and padding_mask is None:
|
1110 |
-
padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
|
1111 |
-
padding_mask[:, -pad_length:] = True
|
1112 |
-
else: padding_mask, _ = pad_to_multiple(padding_mask, self.required_seq_len_multiple, dim=-1, value=True)
|
1113 |
-
|
1114 |
-
x = F.dropout(x, p=self.dropout, training=self.training).transpose(0, 1)
|
1115 |
-
layer_results = []
|
1116 |
-
r = None
|
1117 |
-
|
1118 |
-
for i, layer in enumerate(self.layers):
|
1119 |
-
dropout_probability = np.random.random() if self.layerdrop > 0 else 1
|
1120 |
-
if not self.training or (dropout_probability > self.layerdrop):
|
1121 |
-
layer_check = layer
|
1122 |
-
|
1123 |
-
if (corpus_key is None) or (not isinstance(layer_check, (TransformerSentenceEncoderWithAdapterLayer))): x, (z, lr) = layer(x, self_attn_padding_mask=padding_mask, need_weights=False)
|
1124 |
-
else: x, (z, lr) = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, corpus_key=corpus_key)
|
1125 |
-
|
1126 |
-
if i >= min_layer: layer_results.append((x, z, lr))
|
1127 |
-
if i == tgt_layer:
|
1128 |
-
r = x
|
1129 |
-
break
|
1130 |
-
|
1131 |
-
if r is not None: x = r
|
1132 |
-
x = x.transpose(0, 1)
|
1133 |
-
|
1134 |
-
if pad_length > 0:
|
1135 |
-
x = x[:, :-pad_length]
|
1136 |
-
def undo_pad(a, b, c):
|
1137 |
-
return (a[:-pad_length], b[:-pad_length] if b is not None else b, c[:-pad_length])
|
1138 |
-
|
1139 |
-
layer_results = [undo_pad(*u) for u in layer_results]
|
1140 |
-
|
1141 |
-
return x, layer_results
|
1142 |
-
|
1143 |
-
def max_positions(self):
|
1144 |
-
return self.args.max_positions
|
1145 |
-
|
1146 |
-
def upgrade_state_dict_named(self, state_dict, name):
|
1147 |
-
return state_dict
|
1148 |
-
|
1149 |
-
class Fp32GroupNorm(nn.GroupNorm):
|
1150 |
-
def __init__(self, *args, **kwargs):
|
1151 |
-
super().__init__(*args, **kwargs)
|
1152 |
-
|
1153 |
-
def forward(self, input):
|
1154 |
-
output = F.group_norm(input.float(), self.num_groups, self.weight.float() if self.weight is not None else None, self.bias.float() if self.bias is not None else None, self.eps)
|
1155 |
-
return output.type_as(input)
|
1156 |
-
|
1157 |
-
class Fp32LayerNorm(nn.LayerNorm):
|
1158 |
-
def __init__(self, *args, **kwargs):
|
1159 |
-
super().__init__(*args, **kwargs)
|
1160 |
-
|
1161 |
-
def forward(self, input):
|
1162 |
-
output = F.layer_norm(input.float(), self.normalized_shape, self.weight.float() if self.weight is not None else None, self.bias.float() if self.bias is not None else None, self.eps)
|
1163 |
-
return output.type_as(input)
|
1164 |
-
|
1165 |
-
class ConvFeatureExtractionModel(nn.Module):
|
1166 |
-
def __init__(self, conv_layers, dropout = 0.0, mode = "default", conv_bias = False):
|
1167 |
-
super().__init__()
|
1168 |
-
assert mode in {"default", "layer_norm"}
|
1169 |
-
|
1170 |
-
def block(n_in, n_out, k, stride, is_layer_norm=False, is_group_norm=False, conv_bias=False):
|
1171 |
-
def make_conv():
|
1172 |
-
conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
|
1173 |
-
nn.init.kaiming_normal_(conv.weight)
|
1174 |
-
return conv
|
1175 |
-
|
1176 |
-
assert (is_layer_norm and is_group_norm) == False
|
1177 |
-
|
1178 |
-
if is_layer_norm: return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.Sequential(TransposeLast(), Fp32LayerNorm(dim, elementwise_affine=True), TransposeLast()), nn.GELU())
|
1179 |
-
elif is_group_norm: return nn.Sequential(make_conv(), nn.Dropout(p=dropout), Fp32GroupNorm(dim, dim, affine=True), nn.GELU())
|
1180 |
-
else: return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
|
1181 |
-
|
1182 |
-
in_d = 1
|
1183 |
-
self.conv_layers = nn.ModuleList()
|
1184 |
-
for i, cl in enumerate(conv_layers):
|
1185 |
-
assert len(cl) == 3
|
1186 |
-
(dim, k, stride) = cl
|
1187 |
-
self.conv_layers.append(block(in_d, dim, k, stride, is_layer_norm=mode == "layer_norm", is_group_norm=mode == "default" and i == 0, conv_bias=conv_bias))
|
1188 |
-
in_d = dim
|
1189 |
-
|
1190 |
-
def forward(self, x):
|
1191 |
-
x = x.unsqueeze(1)
|
1192 |
-
for conv in self.conv_layers:
|
1193 |
-
x = conv(x)
|
1194 |
-
|
1195 |
-
return x
|
1196 |
-
|
1197 |
-
class GradMultiply(torch.autograd.Function):
|
1198 |
-
@staticmethod
|
1199 |
-
def forward(ctx, x, scale):
|
1200 |
-
ctx.scale = scale
|
1201 |
-
res = x.new(x)
|
1202 |
-
return res
|
1203 |
-
|
1204 |
-
@staticmethod
|
1205 |
-
def backward(ctx, grad):
|
1206 |
-
return grad * ctx.scale, None
|
1207 |
-
|
1208 |
-
class BaseFairseqModel(nn.Module):
|
1209 |
-
def __init__(self):
|
1210 |
-
super().__init__()
|
1211 |
-
self._is_generation_fast = False
|
1212 |
-
|
1213 |
-
def get_targets(self, sample, net_output):
|
1214 |
-
return sample["target"]
|
1215 |
-
|
1216 |
-
def extract_features(self, *args, **kwargs):
|
1217 |
-
return self(*args, **kwargs)
|
1218 |
-
|
1219 |
-
def load_state_dict(self, state_dict, strict=True, model_cfg = None, args = None):
|
1220 |
-
self.upgrade_state_dict(state_dict)
|
1221 |
-
new_state_dict = prune_state_dict(state_dict, model_cfg)
|
1222 |
-
return super().load_state_dict(new_state_dict, strict)
|
1223 |
-
|
1224 |
-
def upgrade_state_dict(self, state_dict):
|
1225 |
-
self.upgrade_state_dict_named(state_dict, "")
|
1226 |
-
|
1227 |
-
def upgrade_state_dict_named(self, state_dict, name):
|
1228 |
-
assert state_dict is not None
|
1229 |
-
|
1230 |
-
def do_upgrade(m, prefix):
|
1231 |
-
if len(prefix) > 0: prefix += "."
|
1232 |
-
|
1233 |
-
for n, c in m.named_children():
|
1234 |
-
name = prefix + n
|
1235 |
-
if hasattr(c, "upgrade_state_dict_named"): c.upgrade_state_dict_named(state_dict, name)
|
1236 |
-
elif hasattr(c, "upgrade_state_dict"): c.upgrade_state_dict(state_dict)
|
1237 |
-
do_upgrade(c, name)
|
1238 |
-
|
1239 |
-
do_upgrade(self, name)
|
1240 |
-
|
1241 |
-
def make_generation_fast_(self, **kwargs):
|
1242 |
-
if self._is_generation_fast: return
|
1243 |
-
self._is_generation_fast = True
|
1244 |
-
|
1245 |
-
def apply_remove_weight_norm(module):
|
1246 |
-
try:
|
1247 |
-
nn.utils.remove_weight_norm(module)
|
1248 |
-
except (AttributeError, ValueError):
|
1249 |
-
return
|
1250 |
-
|
1251 |
-
self.apply(apply_remove_weight_norm)
|
1252 |
-
|
1253 |
-
def apply_make_generation_fast_(module, prefix):
|
1254 |
-
if len(prefix) > 0: prefix += "."
|
1255 |
-
|
1256 |
-
base_func = BaseFairseqModel.make_generation_fast_
|
1257 |
-
for n, m in module.named_modules():
|
1258 |
-
if (m != self and hasattr(m, "make_generation_fast_") and m.make_generation_fast_.__func__ is not base_func): m.make_generation_fast_(name=prefix + n, **kwargs)
|
1259 |
-
|
1260 |
-
apply_make_generation_fast_(self, "")
|
1261 |
-
self.eval()
|
1262 |
-
|
1263 |
-
class HubertConfig:
|
1264 |
-
def __init__(self, _name, label_rate, encoder_layers_1, logit_temp_ctr, num_negatives, cross_sample_negatives, ctr_layers, extractor_mode = "default", encoder_layers = 12, encoder_embed_dim = 768, encoder_ffn_embed_dim = 3072, encoder_attention_heads = 12, activation_fn = "gelu", layer_type = "transformer", dropout = 0.1, attention_dropout = 0.1, activation_dropout = 0.0, encoder_layerdrop = 0.0, dropout_input = 0.0, dropout_features = 0.0, final_dim = 0, untie_final_proj = False, layer_norm_first = False, conv_feature_layers = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", conv_bias = False, logit_temp = 0.1, target_glu = False, feature_grad_mult = 1.0, mask_length = 10, mask_prob = 0.65, mask_selection = "static", mask_other = 0.0, no_mask_overlap = False, mask_min_space = 1, mask_channel_length = 10, mask_channel_prob = 0.0, mask_channel_selection = "static", mask_channel_other = 0.0, no_mask_channel_overlap = False, mask_channel_min_space = 1, conv_pos = 128, conv_pos_groups = 16, conv_pos_batch_norm = False, latent_temp = (2, 0.5, 0.999995), skip_masked = False, skip_nomask = False, checkpoint_activations = False, required_seq_len_multiple = 2, depthwise_conv_kernel_size = 31, attn_type = "", pos_enc_type = "abs", fp16 = False):
|
1265 |
-
self._name = _name
|
1266 |
-
self.label_rate = label_rate
|
1267 |
-
self.encoder_layers_1 = encoder_layers_1
|
1268 |
-
self.logit_temp_ctr = logit_temp_ctr
|
1269 |
-
self.num_negatives = num_negatives
|
1270 |
-
self.cross_sample_negatives = cross_sample_negatives
|
1271 |
-
self.ctr_layers = ctr_layers
|
1272 |
-
self.extractor_mode = extractor_mode
|
1273 |
-
self.encoder_layers = encoder_layers
|
1274 |
-
self.encoder_embed_dim = encoder_embed_dim
|
1275 |
-
self.encoder_ffn_embed_dim = encoder_ffn_embed_dim
|
1276 |
-
self.encoder_attention_heads = encoder_attention_heads
|
1277 |
-
self.activation_fn = activation_fn
|
1278 |
-
self.layer_type = layer_type
|
1279 |
-
self.dropout = dropout
|
1280 |
-
self.attention_dropout = attention_dropout
|
1281 |
-
self.activation_dropout = activation_dropout
|
1282 |
-
self.encoder_layerdrop = encoder_layerdrop
|
1283 |
-
self.dropout_input = encoder_layerdrop
|
1284 |
-
self.dropout_features = dropout_features
|
1285 |
-
self.final_dim = final_dim
|
1286 |
-
self.untie_final_proj = untie_final_proj
|
1287 |
-
self.layer_norm_first = layer_norm_first
|
1288 |
-
self.conv_feature_layers = conv_feature_layers
|
1289 |
-
self.conv_bias = conv_bias
|
1290 |
-
self.logit_temp = logit_temp
|
1291 |
-
self.target_glu = target_glu
|
1292 |
-
self.feature_grad_mult = feature_grad_mult
|
1293 |
-
self.mask_length = mask_length
|
1294 |
-
self.mask_prob = mask_prob
|
1295 |
-
self.mask_selection = mask_selection
|
1296 |
-
self.mask_other = mask_other
|
1297 |
-
self.no_mask_overlap = no_mask_overlap
|
1298 |
-
self.mask_min_space = mask_min_space
|
1299 |
-
self.mask_channel_length = mask_channel_length
|
1300 |
-
self.mask_channel_prob = mask_channel_prob
|
1301 |
-
self.mask_channel_selection = mask_channel_selection
|
1302 |
-
self.mask_channel_other = mask_channel_other
|
1303 |
-
self.no_mask_channel_overlap = no_mask_channel_overlap
|
1304 |
-
self.mask_channel_min_space = mask_channel_min_space
|
1305 |
-
self.conv_pos = conv_pos
|
1306 |
-
self.conv_pos_groups = conv_pos_groups
|
1307 |
-
self.conv_pos_batch_norm = conv_pos_batch_norm
|
1308 |
-
self.latent_temp = latent_temp
|
1309 |
-
self.skip_masked = skip_masked
|
1310 |
-
self.skip_nomask = skip_nomask
|
1311 |
-
self.checkpoint_activations = checkpoint_activations
|
1312 |
-
self.required_seq_len_multiple = required_seq_len_multiple
|
1313 |
-
self.depthwise_conv_kernel_size = depthwise_conv_kernel_size
|
1314 |
-
self.attn_type = attn_type
|
1315 |
-
self.pos_enc_type = pos_enc_type
|
1316 |
-
self.fp16 = fp16
|
1317 |
-
|
1318 |
-
class Model_Config(dict):
|
1319 |
-
def __getattr__(*args):
|
1320 |
-
val = dict.get(*args)
|
1321 |
-
return Model_Config(val) if type(val) is dict else val
|
1322 |
-
|
1323 |
-
__setattr__ = dict.__setitem__
|
1324 |
-
__delattr__ = dict.__delitem__
|
1325 |
-
|
1326 |
-
class HubertModel(BaseFairseqModel):
|
1327 |
-
def __init__(self, cfg):
|
1328 |
-
super().__init__()
|
1329 |
-
feature_enc_layers = eval(cfg.conv_feature_layers)
|
1330 |
-
self.embed = feature_enc_layers[-1][0]
|
1331 |
-
self.feature_extractor = ConvFeatureExtractionModel(conv_layers=feature_enc_layers, dropout=0.0, mode=cfg.extractor_mode, conv_bias=cfg.conv_bias)
|
1332 |
-
feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
|
1333 |
-
self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / 16000
|
1334 |
-
self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None)
|
1335 |
-
self.mask_prob = cfg.mask_prob
|
1336 |
-
self.mask_selection = cfg.mask_selection
|
1337 |
-
self.mask_other = cfg.mask_other
|
1338 |
-
self.mask_length = cfg.mask_length
|
1339 |
-
self.no_mask_overlap = cfg.no_mask_overlap
|
1340 |
-
self.mask_min_space = cfg.mask_min_space
|
1341 |
-
self.mask_channel_prob = cfg.mask_channel_prob
|
1342 |
-
self.mask_channel_selection = cfg.mask_channel_selection
|
1343 |
-
self.mask_channel_other = cfg.mask_channel_other
|
1344 |
-
self.mask_channel_length = cfg.mask_channel_length
|
1345 |
-
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
|
1346 |
-
self.mask_channel_min_space = cfg.mask_channel_min_space
|
1347 |
-
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
1348 |
-
self.dropout_features = nn.Dropout(cfg.dropout_features)
|
1349 |
-
self.feature_grad_mult = cfg.feature_grad_mult
|
1350 |
-
self.logit_temp = cfg.logit_temp
|
1351 |
-
self.skip_masked = cfg.skip_masked
|
1352 |
-
self.skip_nomask = cfg.skip_nomask
|
1353 |
-
final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
|
1354 |
-
self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_())
|
1355 |
-
self.encoder = TransformerEncoder(cfg)
|
1356 |
-
self.layer_norm = LayerNorm(self.embed)
|
1357 |
-
self.target_glu = None
|
1358 |
-
if cfg.target_glu: self.target_glu = nn.Sequential(nn.Linear(final_dim, final_dim * 2), nn.GLU())
|
1359 |
-
self.untie_final_proj = cfg.untie_final_proj
|
1360 |
-
self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
|
1361 |
-
self.num_classes = [504]
|
1362 |
-
self.label_embs_concat = nn.Parameter(torch.FloatTensor(sum(self.num_classes), final_dim))
|
1363 |
-
nn.init.uniform_(self.label_embs_concat)
|
1364 |
-
|
1365 |
-
def upgrade_state_dict_named(self, state_dict, name):
|
1366 |
-
super().upgrade_state_dict_named(state_dict, name)
|
1367 |
-
return state_dict
|
1368 |
-
|
1369 |
-
def apply_mask(self, x, padding_mask, target_list):
|
1370 |
-
B, T, C = x.shape
|
1371 |
-
if self.mask_prob > 0:
|
1372 |
-
mask_indices = torch.from_numpy(compute_mask_indices((B, T), padding_mask, self.mask_prob, self.mask_length, self.mask_selection, self.mask_other, min_masks=2, no_overlap=self.no_mask_overlap, min_space=self.mask_min_space)).to(x.device)
|
1373 |
-
x[mask_indices] = self.mask_emb
|
1374 |
-
else: mask_indices = None
|
1375 |
-
|
1376 |
-
if self.mask_channel_prob > 0: x[(torch.from_numpy(compute_mask_indices((B, C), None, self.mask_channel_prob, self.mask_channel_length, self.mask_channel_selection, self.mask_channel_other, no_overlap=self.no_mask_channel_overlap, min_space=self.mask_channel_min_space)).to(x.device).unsqueeze(1).expand(-1, T, -1))] = 0
|
1377 |
-
return x, mask_indices
|
1378 |
-
|
1379 |
-
def compute_nce(self, x, pos, negs):
|
1380 |
-
neg_is_pos = (pos == negs).all(-1)
|
1381 |
-
logits = torch.cosine_similarity(x.float(), torch.cat([pos.unsqueeze(0), negs], dim=0).float(), dim=-1).type_as(x)
|
1382 |
-
logits /= self.logit_temp
|
1383 |
-
|
1384 |
-
if neg_is_pos.any(): logits[1:][neg_is_pos] = float("-inf")
|
1385 |
-
return logits.transpose(0, 1)
|
1386 |
-
|
1387 |
-
def forward_features(self, source):
|
1388 |
-
if self.feature_grad_mult > 0:
|
1389 |
-
features = self.feature_extractor(source)
|
1390 |
-
if self.feature_grad_mult != 1.0: features = GradMultiply.apply(features, self.feature_grad_mult)
|
1391 |
-
else:
|
1392 |
-
with torch.no_grad():
|
1393 |
-
features = self.feature_extractor(source)
|
1394 |
-
return features
|
1395 |
-
|
1396 |
-
def forward_targets(self, features, target_list):
|
1397 |
-
feat_tsz = features.size(2)
|
1398 |
-
targ_tsz = min([t.size(1) for t in target_list])
|
1399 |
-
|
1400 |
-
if self.feat2tar_ratio * feat_tsz > targ_tsz:
|
1401 |
-
feat_tsz = int(targ_tsz / self.feat2tar_ratio)
|
1402 |
-
features = features[..., :feat_tsz]
|
1403 |
-
|
1404 |
-
return features, [t[:, (torch.arange(feat_tsz).float() * self.feat2tar_ratio).long()] for t in target_list]
|
1405 |
-
|
1406 |
-
def forward_padding_mask(self, features, padding_mask):
|
1407 |
-
extra = padding_mask.size(1) % features.size(1)
|
1408 |
-
if extra > 0: padding_mask = padding_mask[:, :-extra]
|
1409 |
-
|
1410 |
-
return padding_mask.view(padding_mask.size(0), features.size(1), -1).all(-1)
|
1411 |
-
|
1412 |
-
def forward(self, source, target_list = None, padding_mask = None, mask = True, features_only = False, output_layer = None):
|
1413 |
-
features = self.forward_features(source)
|
1414 |
-
if target_list is not None: features, target_list = self.forward_targets(features, target_list)
|
1415 |
-
|
1416 |
-
features_pen = features.float().pow(2).mean()
|
1417 |
-
|
1418 |
-
features = self.layer_norm(features.transpose(1, 2))
|
1419 |
-
unmasked_features = features.clone()
|
1420 |
-
|
1421 |
-
if padding_mask is not None: padding_mask = self.forward_padding_mask(features, padding_mask)
|
1422 |
-
if self.post_extract_proj is not None: features = self.post_extract_proj(features)
|
1423 |
-
|
1424 |
-
features = self.dropout_input(features)
|
1425 |
-
unmasked_features = self.dropout_features(unmasked_features)
|
1426 |
-
|
1427 |
-
if mask: x, mask_indices = self.apply_mask(features, padding_mask, target_list)
|
1428 |
-
else: x, mask_indices = features, None
|
1429 |
-
|
1430 |
-
x, _ = self.encoder(x, padding_mask=padding_mask, layer=None if output_layer is None else output_layer - 1)
|
1431 |
-
if features_only: return {"x": x, "padding_mask": padding_mask, "features": features}
|
1432 |
-
|
1433 |
-
def compute_pred(proj_x, target, label_embs):
|
1434 |
-
y = torch.index_select(label_embs, 0, target.long())
|
1435 |
-
negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
|
1436 |
-
|
1437 |
-
if self.target_glu:
|
1438 |
-
y = self.target_glu(y)
|
1439 |
-
negs = self.target_glu(negs)
|
1440 |
-
|
1441 |
-
return self.compute_nce(proj_x, y, negs)
|
1442 |
-
|
1443 |
-
label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
|
1444 |
-
|
1445 |
-
if not self.skip_masked:
|
1446 |
-
masked_indices = torch.logical_and(~padding_mask, mask_indices)
|
1447 |
-
proj_x_m = self.final_proj(x[masked_indices])
|
1448 |
-
logit_m_list = [compute_pred(proj_x_m, t[masked_indices], label_embs_list[i]) for i, (proj_x_m, t) in enumerate(zip(proj_x_m.chunk(len(target_list), dim=-1) if self.untie_final_proj else [proj_x_m for _ in range(len(target_list))], target_list))]
|
1449 |
-
else: logit_m_list = [None for _ in target_list]
|
1450 |
-
|
1451 |
-
if not self.skip_nomask:
|
1452 |
-
nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
|
1453 |
-
proj_x_u = self.final_proj(x[nomask_indices])
|
1454 |
-
logit_u_list = [compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i]) for i, (proj_x_u, t) in enumerate(zip(proj_x_u.chunk(len(target_list), dim=-1) if self.untie_final_proj else [proj_x_u for _ in range(len(target_list))], target_list))]
|
1455 |
-
else: logit_u_list = [None for _ in target_list]
|
1456 |
-
|
1457 |
-
return {"logit_m_list": logit_m_list, "logit_u_list": logit_u_list, "padding_mask": padding_mask, "features_pen": features_pen}
|
1458 |
-
|
1459 |
-
def extract_features(self, source, padding_mask = None, mask = False, ret_conv = False, output_layer = None):
|
1460 |
-
res = self.forward(source, padding_mask=padding_mask, mask=mask, features_only=True, output_layer=output_layer)
|
1461 |
-
return res["features"] if ret_conv else res["x"], res["padding_mask"]
|
1462 |
-
|
1463 |
-
def get_logits(self, net_output, is_masked=True):
|
1464 |
-
return [x.float() for x in (net_output["logit_m_list"] if is_masked else net_output["logit_u_list"]) if x is not None]
|
1465 |
-
|
1466 |
-
def get_targets(self, net_output, is_masked=True):
|
1467 |
-
return [x.new_zeros(x.size(0), dtype=torch.long) for x in self.get_logits(net_output, is_masked)]
|
1468 |
-
|
1469 |
-
def get_extra_losses(self, net_output):
|
1470 |
-
extra_losses, names = [], []
|
1471 |
-
|
1472 |
-
if "features_pen" in net_output:
|
1473 |
-
extra_losses.append(net_output["features_pen"])
|
1474 |
-
names.append("features_pen")
|
1475 |
-
|
1476 |
-
return extra_losses, names
|
1477 |
-
|
1478 |
-
def remove_pretraining_modules(self):
|
1479 |
-
self.target_glu = None
|
1480 |
-
self.final_proj = None
|
|
|
|
|
|
|
|
main/library/architectures/mdx_separator.py
DELETED
@@ -1,320 +0,0 @@
import os
import sys
import onnx
import torch
import platform
import onnx2torch

import numpy as np
import onnxruntime as ort

from tqdm import tqdm

sys.path.append(os.getcwd())

from main.configs.config import Config
from main.library.uvr5_separator import spec_utils
from main.library.uvr5_separator.common_separator import CommonSeparator

translations = Config().translations

class MDXSeparator(CommonSeparator):
    def __init__(self, common_config, arch_config):
        super().__init__(config=common_config)
        self.segment_size = arch_config.get("segment_size")
        self.overlap = arch_config.get("overlap")
        self.batch_size = arch_config.get("batch_size", 1)
        self.hop_length = arch_config.get("hop_length")
        self.enable_denoise = arch_config.get("enable_denoise")
        self.logger.debug(translations["mdx_info"].format(batch_size=self.batch_size, segment_size=self.segment_size))
        self.logger.debug(translations["mdx_info_2"].format(overlap=self.overlap, hop_length=self.hop_length, enable_denoise=self.enable_denoise))
        self.compensate = self.model_data["compensate"]
        self.dim_f = self.model_data["mdx_dim_f_set"]
        self.dim_t = 2 ** self.model_data["mdx_dim_t_set"]
        self.n_fft = self.model_data["mdx_n_fft_scale_set"]
        self.config_yaml = self.model_data.get("config_yaml", None)
        self.logger.debug(f"{translations['mdx_info_3']}: compensate = {self.compensate}, dim_f = {self.dim_f}, dim_t = {self.dim_t}, n_fft = {self.n_fft}")
        self.logger.debug(f"{translations['mdx_info_3']}: config_yaml = {self.config_yaml}")
        self.load_model()
        self.n_bins = 0
        self.trim = 0
        self.chunk_size = 0
        self.gen_size = 0
        self.stft = None
        self.primary_source = None
        self.secondary_source = None
        self.audio_file_path = None
        self.audio_file_base = None

    def load_model(self):
        self.logger.debug(translations["load_model_onnx"])

        if self.segment_size == self.dim_t:
            ort_session_options = ort.SessionOptions()
            ort_session_options.log_severity_level = 3 if self.log_level > 10 else 0
            ort_inference_session = ort.InferenceSession(self.model_path, providers=self.onnx_execution_provider, sess_options=ort_session_options)
            self.model_run = lambda spek: ort_inference_session.run(None, {"input": spek.cpu().numpy()})[0]
            self.logger.debug(translations["load_model_onnx_success"])
        else:
            self.model_run = onnx2torch.convert(onnx.load(self.model_path)) if platform.system() == 'Windows' else onnx2torch.convert(self.model_path)
            self.model_run.to(self.torch_device).eval()
            self.logger.debug(translations["onnx_to_pytorch"])

    def separate(self, audio_file_path):
        self.audio_file_path = audio_file_path
        self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
        self.logger.debug(translations["mix"].format(audio_file_path=self.audio_file_path))
        mix = self.prepare_mix(self.audio_file_path)
        self.logger.debug(translations["normalization_demix"])
        mix = spec_utils.normalize(wave=mix, max_peak=self.normalization_threshold)
        source = self.demix(mix)
        self.logger.debug(translations["mix_success"])
        output_files = []
        self.logger.debug(translations["process_output_file"])

        if not isinstance(self.primary_source, np.ndarray):
            self.logger.debug(translations["primary_source"])
            self.primary_source = spec_utils.normalize(wave=source, max_peak=self.normalization_threshold).T

        if not isinstance(self.secondary_source, np.ndarray):
            self.logger.debug(translations["secondary_source"])
            raw_mix = self.demix(mix, is_match_mix=True)

            if self.invert_using_spec:
                self.logger.debug(translations["invert_using_spec"])
                self.secondary_source = spec_utils.invert_stem(raw_mix, source)
            else:
                self.logger.debug(translations["invert_using_spec_2"])
                self.secondary_source = mix.T - source.T

        if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
            self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")
            self.logger.info(translations["save_secondary_stem_output_path"].format(stem_name=self.secondary_stem_name, stem_output_path=self.secondary_stem_output_path))
            self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
            output_files.append(self.secondary_stem_output_path)

        if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
            self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")
            if not isinstance(self.primary_source, np.ndarray): self.primary_source = source.T

            self.logger.info(translations["save_secondary_stem_output_path"].format(stem_name=self.primary_stem_name, stem_output_path=self.primary_stem_output_path))
            self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
            output_files.append(self.primary_stem_output_path)

        return output_files

    def initialize_model_settings(self):
        self.logger.debug(translations["starting_model"])

        self.n_bins = self.n_fft // 2 + 1
        self.trim = self.n_fft // 2

        self.chunk_size = self.hop_length * (self.segment_size - 1)
        self.gen_size = self.chunk_size - 2 * self.trim

        self.stft = STFT(self.logger, self.n_fft, self.hop_length, self.dim_f, self.torch_device)

        self.logger.debug(f"{translations['input_info']}: n_fft = {self.n_fft} hop_length = {self.hop_length} dim_f = {self.dim_f}")
        self.logger.debug(f"{translations['model_settings']}: n_bins = {self.n_bins}, Trim = {self.trim}, chunk_size = {self.chunk_size}, gen_size = {self.gen_size}")

    def initialize_mix(self, mix, is_ckpt=False):
        self.logger.debug(translations["initialize_mix"].format(is_ckpt=is_ckpt, shape=mix.shape))

        if mix.shape[0] != 2:
            error_message = translations["!=2"].format(shape=mix.shape[0])
            self.logger.error(error_message)
            raise ValueError(error_message)

        if is_ckpt:
            self.logger.debug(translations["process_check"])
            pad = self.gen_size + self.trim - (mix.shape[-1] % self.gen_size)
            self.logger.debug(f"{translations['cache']}: {pad}")

            mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)

            num_chunks = mixture.shape[-1] // self.gen_size
            self.logger.debug(translations["shape"].format(shape=mixture.shape, num_chunks=num_chunks))

            mix_waves = [mixture[:, i * self.gen_size : i * self.gen_size + self.chunk_size] for i in range(num_chunks)]
        else:
            self.logger.debug(translations["process_no_check"])
            mix_waves = []
            n_sample = mix.shape[1]

            pad = self.gen_size - n_sample % self.gen_size
            self.logger.debug(translations["n_sample_or_pad"].format(n_sample=n_sample, pad=pad))

            mix_p = np.concatenate((np.zeros((2, self.trim)), mix, np.zeros((2, pad)), np.zeros((2, self.trim))), 1)
            self.logger.debug(f"{translations['shape_2']}: {mix_p.shape}")

            i = 0
            while i < n_sample + pad:
                mix_waves.append(np.array(mix_p[:, i : i + self.chunk_size]))

                self.logger.debug(translations["process_part"].format(mix_waves=len(mix_waves), i=i, ii=i + self.chunk_size))
                i += self.gen_size

        mix_waves_tensor = torch.tensor(mix_waves, dtype=torch.float32).to(self.torch_device)
        self.logger.debug(translations["mix_waves_to_tensor"].format(shape=mix_waves_tensor.shape))

        return mix_waves_tensor, pad

    def demix(self, mix, is_match_mix=False):
        self.logger.debug(f"{translations['demix_is_match_mix']}: {is_match_mix}...")
        self.initialize_model_settings()
        self.logger.debug(f"{translations['mix_shape']}: {mix.shape}")
        tar_waves_ = []

        if is_match_mix:
            chunk_size = self.hop_length * (self.segment_size - 1)
            overlap = 0.02
            self.logger.debug(translations["chunk_size_or_overlap"].format(chunk_size=chunk_size, overlap=overlap))
        else:
            chunk_size = self.chunk_size
            overlap = self.overlap
            self.logger.debug(translations["chunk_size_or_overlap_standard"].format(chunk_size=chunk_size, overlap=overlap))

        gen_size = chunk_size - 2 * self.trim
        self.logger.debug(f"{translations['calc_size']}: {gen_size}")

        mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, gen_size + self.trim - ((mix.shape[-1]) % gen_size)), dtype="float32")), 1)
        self.logger.debug(f"{translations['mix_cache']}: {mixture.shape}")

        step = int((1 - overlap) * chunk_size)
        self.logger.debug(translations["step_or_overlap"].format(step=step, overlap=overlap))

        result = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
        divider = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)

        total = 0
        total_chunks = (mixture.shape[-1] + step - 1) // step
        self.logger.debug(f"{translations['all_process_part']}: {total_chunks}")

        for i in tqdm(range(0, mixture.shape[-1], step), ncols=100, unit="f"):
            total += 1
            start = i
            end = min(i + chunk_size, mixture.shape[-1])
            self.logger.debug(translations["process_part_2"].format(total=total, total_chunks=total_chunks, start=start, end=end))

            chunk_size_actual = end - start
            window = None

            if overlap != 0:
                window = np.hanning(chunk_size_actual)
                window = np.tile(window[None, None, :], (1, 2, 1))
                self.logger.debug(translations["window"])

            mix_part_ = mixture[:, start:end]

            if end != i + chunk_size:
                pad_size = (i + chunk_size) - end
                mix_part_ = np.concatenate((mix_part_, np.zeros((2, pad_size), dtype="float32")), axis=-1)

            mix_waves = torch.tensor([mix_part_], dtype=torch.float32).to(self.torch_device).split(self.batch_size)

            total_batches = len(mix_waves)
            self.logger.debug(f"{translations['mix_or_batch']}: {total_batches}")

            with torch.no_grad():
                batches_processed = 0

                for mix_wave in mix_waves:
                    batches_processed += 1
                    self.logger.debug(f"{translations['mix_wave']} {batches_processed}/{total_batches}")

                    tar_waves = self.run_model(mix_wave, is_match_mix=is_match_mix)

                    if window is not None:
                        tar_waves[..., :chunk_size_actual] *= window
                        divider[..., start:end] += window
                    else: divider[..., start:end] += 1

                    result[..., start:end] += tar_waves[..., : end - start]

        self.logger.debug(translations["normalization_2"])
        tar_waves = result / divider
        tar_waves_.append(tar_waves)

        tar_waves = np.concatenate(np.vstack(tar_waves_)[:, :, self.trim : -self.trim], axis=-1)[:, : mix.shape[-1]]

        source = tar_waves[:, 0:None]
        self.logger.debug(f"{translations['tar_waves']}: {tar_waves.shape}")

        if not is_match_mix:
            source *= self.compensate
            self.logger.debug(translations["mix_match"])

        self.logger.debug(translations["mix_success"])
        return source

    def run_model(self, mix, is_match_mix=False):
        spek = self.stft(mix.to(self.torch_device))
        self.logger.debug(translations["stft_2"].format(shape=spek.shape))

        spek[:, :, :3, :] *= 0

        if is_match_mix:
            spec_pred = spek.cpu().numpy()
            self.logger.debug(translations["is_match_mix"])
        else:
            if self.enable_denoise:
                spec_pred_neg = self.model_run(-spek)
                spec_pred_pos = self.model_run(spek)
                spec_pred = (spec_pred_neg * -0.5) + (spec_pred_pos * 0.5)
                self.logger.debug(translations["enable_denoise"])
            else:
                spec_pred = self.model_run(spek)
                self.logger.debug(translations["no_denoise"])

        result = self.stft.inverse(torch.tensor(spec_pred).to(self.torch_device)).cpu().detach().numpy()
        self.logger.debug(f"{translations['stft']}: {result.shape}")

        return result

class STFT:
    def __init__(self, logger, n_fft, hop_length, dim_f, device):
        self.logger = logger
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.dim_f = dim_f
        self.device = device
        self.hann_window = torch.hann_window(window_length=self.n_fft, periodic=True)

    def __call__(self, input_tensor):
        is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]

        if is_non_standard_device: input_tensor = input_tensor.cpu()

        batch_dimensions = input_tensor.shape[:-2]
        channel_dim, time_dim = input_tensor.shape[-2:]

        permuted_stft_output = torch.stft(input_tensor.reshape([-1, time_dim]), n_fft=self.n_fft, hop_length=self.hop_length, window=self.hann_window.to(input_tensor.device), center=True, return_complex=False).permute([0, 3, 1, 2])
        final_output = permuted_stft_output.reshape([*batch_dimensions, channel_dim, 2, -1, permuted_stft_output.shape[-1]]).reshape([*batch_dimensions, channel_dim * 2, -1, permuted_stft_output.shape[-1]])

        if is_non_standard_device: final_output = final_output.to(self.device)
        return final_output[..., : self.dim_f, :]

    def pad_frequency_dimension(self, input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins):
        return torch.cat([input_tensor, torch.zeros([*batch_dimensions, channel_dim, num_freq_bins - freq_dim, time_dim]).to(input_tensor.device)], -2)

    def calculate_inverse_dimensions(self, input_tensor):
        channel_dim, freq_dim, time_dim = input_tensor.shape[-3:]

        return input_tensor.shape[:-3], channel_dim, freq_dim, time_dim, self.n_fft // 2 + 1

    def prepare_for_istft(self, padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim):
        permuted_tensor = padded_tensor.reshape([*batch_dimensions, channel_dim // 2, 2, num_freq_bins, time_dim]).reshape([-1, 2, num_freq_bins, time_dim]).permute([0, 2, 3, 1])

        return permuted_tensor[..., 0] + permuted_tensor[..., 1] * 1.0j

    def inverse(self, input_tensor):
        is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]
        if is_non_standard_device: input_tensor = input_tensor.cpu()

        batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins = self.calculate_inverse_dimensions(input_tensor)
        final_output = torch.istft(self.prepare_for_istft(self.pad_frequency_dimension(input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins), batch_dimensions, channel_dim, num_freq_bins, time_dim), n_fft=self.n_fft, hop_length=self.hop_length, window=self.hann_window.to(input_tensor.device), center=True).reshape([*batch_dimensions, 2, -1])

        if is_non_standard_device: final_output = final_output.to(self.device)

        return final_output
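A hedged round-trip sketch of the STFT helper defined above; the parameter values and the use of a plain logging.Logger are illustrative assumptions (real values come from the MDX model metadata).

import logging
import torch

# Illustrative settings; in the separator they are derived from n_fft, hop_length and dim_f.
stft = STFT(logging.getLogger(__name__), n_fft=6144, hop_length=1024, dim_f=2048, device="cpu")
wave = torch.randn(1, 2, 6144 * 4)   # (batch, stereo channels, samples)
spec = stft(wave)                    # (batch, 4, dim_f, frames): real/imag stacked per channel
recon = stft.inverse(spec)           # back to (batch, 2, samples), band-limited to dim_f bins
print(spec.shape, recon.shape)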
main/library/audioldm2/models.py
DELETED
@@ -1,330 +0,0 @@
import os
import sys
import torch
import librosa

import numpy as np
import torch.nn.functional as F

from scipy.signal import get_window
from librosa.util import pad_center
from diffusers import DDIMScheduler, AudioLDM2Pipeline
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
from transformers import RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer

sys.path.append(os.getcwd())

from main.configs.config import Config
from main.library.utils import check_audioldm2

config = Config()

class Pipeline(torch.nn.Module):
    def __init__(self, model_id, device, double_precision = False, token = None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_id = model_id
        self.device = device
        self.double_precision = double_precision
        self.token = token

    def load_scheduler(self):
        pass

    def get_melspectrogram(self):
        pass

    def vae_encode(self, x):
        pass

    def vae_decode(self, x):
        pass

    def decode_to_mel(self, x):
        pass

    def setup_extra_inputs(self, *args, **kwargs):
        pass

    def encode_text(self, prompts, **kwargs):
        pass

    def get_variance(self, timestep, prev_timestep):
        pass

    def get_alpha_prod_t_prev(self, prev_timestep):
        pass

    def get_noise_shape(self, x0, num_steps):
        return (num_steps, self.model.unet.config.in_channels, x0.shape[-2], x0.shape[-1])

    def sample_xts_from_x0(self, x0, num_inference_steps = 50):
        alpha_bar = self.model.scheduler.alphas_cumprod
        sqrt_one_minus_alpha_bar = (1 - alpha_bar) ** 0.5
        timesteps = self.model.scheduler.timesteps.to(self.device)
        t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
        xts = torch.zeros(self.get_noise_shape(x0, num_inference_steps + 1)).to(x0.device)
        xts[0] = x0

        for t in reversed(timesteps):
            idx = num_inference_steps - t_to_idx[int(t)]
            xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]

        return xts

    def get_zs_from_xts(self, xt, xtm1, noise_pred, t, eta = 0, numerical_fix = True, **kwargs):
        alpha_bar = self.model.scheduler.alphas_cumprod

        if self.model.scheduler.config.prediction_type == 'epsilon': pred_original_sample = (xt - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
        elif self.model.scheduler.config.prediction_type == 'v_prediction': pred_original_sample = (alpha_bar[t] ** 0.5) * xt - ((1 - alpha_bar[t]) ** 0.5) * noise_pred

        prev_timestep = t - self.model.scheduler.config.num_train_timesteps // self.model.scheduler.num_inference_steps
        alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
        variance = self.get_variance(t, prev_timestep)

        if self.model.scheduler.config.prediction_type == 'epsilon': radom_noise_pred = noise_pred
        elif self.model.scheduler.config.prediction_type == 'v_prediction': radom_noise_pred = (alpha_bar[t] ** 0.5) * noise_pred + ((1 - alpha_bar[t]) ** 0.5) * xt

        mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + ((1 - alpha_prod_t_prev - eta * variance) ** (0.5) * radom_noise_pred)
        z = (xtm1 - mu_xt) / (eta * variance ** 0.5)

        if numerical_fix: xtm1 = mu_xt + (eta * variance ** 0.5)*z
        return z, xtm1, None

    def reverse_step_with_custom_noise(self, model_output, timestep, sample, variance_noise = None, eta = 0, **kwargs):
        prev_timestep = timestep - self.model.scheduler.config.num_train_timesteps // self.model.scheduler.num_inference_steps
        alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
        beta_prod_t = 1 - alpha_prod_t

        if self.model.scheduler.config.prediction_type == 'epsilon': pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        elif self.model.scheduler.config.prediction_type == 'v_prediction': pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output

        variance = self.get_variance(timestep, prev_timestep)

        if self.model.scheduler.config.prediction_type == 'epsilon': model_output_direction = model_output
        elif self.model.scheduler.config.prediction_type == 'v_prediction': model_output_direction = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample

        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + ((1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction)

        if eta > 0:
            if variance_noise is None: variance_noise = torch.randn(model_output.shape, device=self.device)
            prev_sample = prev_sample + (eta * variance ** (0.5) * variance_noise)

        return prev_sample

    def unet_forward(self, sample, timestep, encoder_hidden_states, class_labels = None, timestep_cond = None, attention_mask = None, cross_attention_kwargs = None, added_cond_kwargs = None, down_block_additional_residuals = None, mid_block_additional_residual = None, encoder_attention_mask = None, replace_h_space = None, replace_skip_conns = None, return_dict = True, zero_out_resconns = None):
        pass

class STFT(torch.nn.Module):
    def __init__(self, fft_size, hop_size, window_size, window_type="hann"):
        super().__init__()
        self.fft_size = fft_size
        self.hop_size = hop_size
        self.window_size = window_size
        self.window_type = window_type

        scale = fft_size / hop_size
        fourier_basis = np.fft.fft(np.eye(fft_size))

        cutoff = fft_size // 2 + 1
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])])

        self.forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        self.inverse_basis = torch.FloatTensor(np.linalg.pinv(scale * fourier_basis).T[:, None, :])

        if window_type:
            assert fft_size >= window_size

            fft_window = torch.from_numpy(pad_center(get_window(window_type, window_size, fftbins=True), size=fft_size)).float()
            self.forward_basis *= fft_window
            self.inverse_basis *= fft_window

        if not hasattr(self, "forward_basis"): self.register_buffer("forward_basis", self.forward_basis)
        if not hasattr(self, "inverse_basis"): self.register_buffer("inverse_basis", self.inverse_basis)

    def transform(self, signal):
        batch_size, num_samples = signal.shape
        transformed_signal = F.conv1d(F.pad(signal.view(batch_size, 1, num_samples).unsqueeze(1), (self.fft_size // 2, self.fft_size // 2, 0, 0), mode="reflect").squeeze(1), self.forward_basis, stride=self.hop_size, padding=0).cpu()

        cutoff = self.fft_size // 2 + 1
        real_part, imag_part = transformed_signal[:, :cutoff, :], transformed_signal[:, cutoff:, :]

        return torch.sqrt(real_part ** 2 + imag_part ** 2), torch.atan2(imag_part, real_part)

class MelSpectrogramProcessor(torch.nn.Module):
    def __init__(self, fft_size, hop_size, window_size, num_mel_bins, sample_rate, fmin, fmax):
        super().__init__()
        self.num_mel_bins = num_mel_bins
        self.sample_rate = sample_rate
        self.stft_processor = STFT(fft_size, hop_size, window_size)
        self.register_buffer("mel_filter", torch.from_numpy(librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mel_bins, fmin=fmin, fmax=fmax)).float())

    def compute_mel_spectrogram(self, waveform, normalization_fn=torch.log):
        assert torch.min(waveform) >= -1
        assert torch.max(waveform) <= 1

        magnitudes, _ = self.stft_processor.transform(waveform)
        return normalization_fn(torch.clamp(torch.matmul(self.mel_filter, magnitudes), min=1e-5))

class AudioLDM2(Pipeline):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, local_files_only=True, torch_dtype=torch.float16 if config.is_half else torch.float32).to(self.device)

    def load_scheduler(self):
        self.model.scheduler = DDIMScheduler.from_pretrained(self.model_id, subfolder="scheduler")

    def get_melspectrogram(self):
        return MelSpectrogramProcessor(fft_size=1024, hop_size=160, window_size=1024, num_mel_bins=64, sample_rate=16000, fmin=0, fmax=8000)

    def vae_encode(self, x):
        if x.shape[2] % 4: x = F.pad(x, (0, 0, 4 - (x.shape[2] % 4), 0))
        output = (self.model.vae.encode(x.half() if config.is_half else x.float()).latent_dist.mode() * self.model.vae.config.scaling_factor)
        return output.half() if config.is_half else output.float()

    def vae_decode(self, x):
        return self.model.vae.decode(1 / self.model.vae.config.scaling_factor * x).sample

    def decode_to_mel(self, x):
        tmp = self.model.mel_spectrogram_to_waveform(x[:, 0].detach().to(torch.float16 if config.is_half else torch.float32)).detach()

        if len(tmp.shape) == 1: tmp = tmp.unsqueeze(0)
        return tmp

    def encode_text(self, prompts, negative = False, save_compute = False, cond_length = 0, **kwargs):
        tokenizers, text_encoders = [self.model.tokenizer, self.model.tokenizer_2], [self.model.text_encoder, self.model.text_encoder_2]
        prompt_embeds_list, attention_mask_list = [], []

        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(prompts, padding="max_length" if (save_compute and negative) or isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast)) else True, max_length=tokenizer.model_max_length if (not save_compute) or ((not negative) or isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer))) else cond_length, truncation=True, return_tensors="pt")
            text_input_ids = text_inputs.input_ids

            attention_mask = text_inputs.attention_mask
            untruncated_ids = tokenizer(prompts, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1: -1])

            text_input_ids = text_input_ids.to(self.device)
            attention_mask = attention_mask.to(self.device)

            with torch.no_grad():
                if text_encoder.config.model_type == "clap":
                    prompt_embeds = text_encoder.get_text_features(text_input_ids, attention_mask=attention_mask)
                    prompt_embeds = prompt_embeds[:, None, :]
                    attention_mask = attention_mask.new_ones((len(prompts), 1))
                else: prompt_embeds = text_encoder(text_input_ids, attention_mask=attention_mask)[0]

            prompt_embeds_list.append(prompt_embeds)
            attention_mask_list.append(attention_mask)

        projection_output = self.model.projection_model(hidden_states=prompt_embeds_list[0], hidden_states_1=prompt_embeds_list[1], attention_mask=attention_mask_list[0], attention_mask_1=attention_mask_list[1])
        generated_prompt_embeds = self.model.generate_language_model(projection_output.hidden_states, attention_mask=projection_output.attention_mask, max_new_tokens=None)
        prompt_embeds = prompt_embeds.to(dtype=self.model.text_encoder_2.dtype, device=self.device)
        return generated_prompt_embeds.to(dtype=self.model.language_model.dtype, device=self.device), prompt_embeds, (attention_mask.to(device=self.device) if attention_mask is not None else torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=self.device))

    def get_variance(self, timestep, prev_timestep):
        alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
        return ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) * (1 - alpha_prod_t / alpha_prod_t_prev)

    def get_alpha_prod_t_prev(self, prev_timestep):
        return self.model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.model.scheduler.final_alpha_cumprod

    def unet_forward(self, sample, timestep, encoder_hidden_states, timestep_cond = None, class_labels = None, attention_mask = None, encoder_attention_mask = None, return_dict = True, cross_attention_kwargs = None, mid_block_additional_residual = None, replace_h_space = None, replace_skip_conns = None, zero_out_resconns = None):
        encoder_hidden_states_1 = class_labels
        class_labels = None
        encoder_attention_mask_1 = encoder_attention_mask
        encoder_attention_mask = None
        default_overall_up_factor = 2 ** self.model.unet.num_upsamplers
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): forward_upsample_size = True

        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        if encoder_attention_mask_1 is not None:
            encoder_attention_mask_1 = (1 - encoder_attention_mask_1.to(sample.dtype)) * -10000.0
            encoder_attention_mask_1 = encoder_attention_mask_1.unsqueeze(1)

        timesteps = timestep
        if not torch.is_tensor(timesteps):
            is_mps = sample.device.type == "mps"

            dtype = (torch.float16 if is_mps else torch.float32) if isinstance(timestep, float) else (torch.int16 if is_mps else torch.int32)

            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device)

        emb = self.model.unet.time_embedding(self.model.unet.time_proj(timesteps.expand(sample.shape[0])).to(dtype=sample.dtype), timestep_cond)
        aug_emb = None

        if self.model.unet.class_embedding is not None:
            if class_labels is None: raise ValueError

            if self.model.unet.config.class_embed_type == "timestep": class_labels = self.model.unet.time_proj(class_labels).to(dtype=sample.dtype)
            class_emb = self.model.unet.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.model.unet.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1)
            else: emb = emb + class_emb

        emb = emb + aug_emb if aug_emb is not None else emb
        if self.model.unet.time_embed_act is not None: emb = self.model.unet.time_embed_act(emb)

        sample = self.model.unet.conv_in(sample)
        down_block_res_samples = (sample,)

        for downsample_block in self.model.unet.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, encoder_hidden_states_1=encoder_hidden_states_1, encoder_attention_mask_1=encoder_attention_mask_1)
            else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        if self.model.unet.mid_block is not None: sample = self.model.unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, encoder_hidden_states_1=encoder_hidden_states_1, encoder_attention_mask_1=encoder_attention_mask_1)

        if replace_h_space is None: h_space = sample.clone()
        else:
            h_space = replace_h_space
            sample = replace_h_space.clone()

        if mid_block_additional_residual is not None: sample = sample + mid_block_additional_residual
        extracted_res_conns = {}

        for i, upsample_block in enumerate(self.model.unet.up_blocks):
            is_final_block = i == len(self.model.unet.up_blocks) - 1
            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if replace_skip_conns is not None and replace_skip_conns.get(i): res_samples = replace_skip_conns.get(i)

            if zero_out_resconns is not None:
                if (type(zero_out_resconns) is int and i >= (zero_out_resconns - 1)) or type(zero_out_resconns) is list and i in zero_out_resconns: res_samples = [torch.zeros_like(x) for x in res_samples]

            extracted_res_conns[i] = res_samples
            if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, encoder_hidden_states_1=encoder_hidden_states_1, encoder_attention_mask_1=encoder_attention_mask_1)
            else: sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size)

        if self.model.unet.conv_norm_out: sample = self.model.unet.conv_act(self.model.unet.conv_norm_out(sample))
        sample = self.model.unet.conv_out(sample)

        if not return_dict: return (sample,)
        return UNet2DConditionOutput(sample=sample), h_space, extracted_res_conns

def load_model(model, device):
    check_audioldm2(model)

    ldm_stable = AudioLDM2(model_id=os.path.join("assets", "models", "audioldm2", model), device=device, double_precision=False)
    ldm_stable.load_scheduler()

    if torch.cuda.is_available(): torch.cuda.empty_cache()
    elif torch.backends.mps.is_available(): torch.mps.empty_cache()

    return ldm_stable
main/library/audioldm2/utils.py
DELETED
@@ -1,40 +0,0 @@
import torch
import librosa
import torchaudio

import numpy as np

def compute_mel_spectrogram(audio, stft_processor):
    return stft_processor.compute_mel_spectrogram(torch.autograd.Variable(torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1), requires_grad=False)).squeeze(0).numpy().astype(np.float32)

def pad_spectrogram(spectrogram, target_length=1024):
    pad_amount = target_length - spectrogram.shape[0]
    spectrogram = torch.nn.functional.pad(spectrogram, (0, 0, 0, pad_amount)) if pad_amount > 0 else spectrogram[:target_length, :]

    if spectrogram.size(-1) % 2 != 0: spectrogram = spectrogram[..., :-1]
    return spectrogram

def pad_waveform(waveform, segment_length):
    waveform_length = waveform.shape[-1]
    assert waveform_length > 100

    if segment_length is None or waveform_length == segment_length: return waveform
    elif waveform_length > segment_length: return waveform[:, :segment_length]

    padded_waveform = np.zeros((1, segment_length))
    padded_waveform[:, :waveform_length] = waveform
    return padded_waveform

def normalize(waveform):
    waveform -= np.mean(waveform)
    return (waveform / (np.max(np.abs(waveform)) + 1e-8)) * 0.5

def process_audio(y, sr, segment_length):
    normalized_waveform = normalize(torchaudio.functional.resample(torch.from_numpy(y), orig_freq=sr, new_freq=16000).numpy())[None, ...]
    return 0.5 * (pad_waveform(normalized_waveform, segment_length) / np.max(np.abs(normalized_waveform)))

def load_audio(audio_path, stft_processor, device=None):
    y, sr = librosa.load(audio_path, sr=None)
    duration = len(y) / sr

    return pad_spectrogram(torch.FloatTensor(compute_mel_spectrogram(torch.FloatTensor(process_audio(y, sr, int(duration * 102.4) * 160)[0, ...]), stft_processor).T), int(duration * 102.4)).unsqueeze(0).unsqueeze(0).to(device), duration
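A hedged usage sketch of load_audio above; the file name is a placeholder, and the mel settings mirror AudioLDM2.get_melspectrogram from the deleted models.py.

# MelSpectrogramProcessor comes from the deleted main/library/audioldm2/models.py above.
stft_processor = MelSpectrogramProcessor(fft_size=1024, hop_size=160, window_size=1024, num_mel_bins=64, sample_rate=16000, fmin=0, fmax=8000)
mel, duration = load_audio("example.wav", stft_processor, device="cpu")  # "example.wav" is a placeholder path
print(mel.shape, duration)  # roughly (1, 1, frames, 64) plus the clip length in seconds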
main/library/predictors/CREPE.py
DELETED
@@ -1,210 +0,0 @@
import os
import torch
import librosa
import functools
import scipy.stats

import numpy as np

CENTS_PER_BIN, MAX_FMAX, PITCH_BINS, SAMPLE_RATE, WINDOW_SIZE = 20, 2006, 360, 16000, 1024

class Crepe(torch.nn.Module):
    def __init__(self, model='full'):
        super().__init__()
        if model == 'full':
            in_channels = [1, 1024, 128, 128, 128, 256]
            out_channels = [1024, 128, 128, 128, 256, 512]
            self.in_features = 2048
        elif model == 'large':
            in_channels = [1, 768, 96, 96, 96, 192]
            out_channels = [768, 96, 96, 96, 192, 384]
            self.in_features = 1536
        elif model == 'medium':
            in_channels = [1, 512, 64, 64, 64, 128]
            out_channels = [512, 64, 64, 64, 128, 256]
            self.in_features = 1024
        elif model == 'small':
            in_channels = [1, 256, 32, 32, 32, 64]
            out_channels = [256, 32, 32, 32, 64, 128]
            self.in_features = 512
        elif model == 'tiny':
            in_channels = [1, 128, 16, 16, 16, 32]
            out_channels = [128, 16, 16, 16, 32, 64]
            self.in_features = 256

        kernel_sizes = [(512, 1)] + 5 * [(64, 1)]
        strides = [(4, 1)] + 5 * [(1, 1)]

        batch_norm_fn = functools.partial(torch.nn.BatchNorm2d, eps=0.0010000000474974513, momentum=0.0)

        self.conv1 = torch.nn.Conv2d(in_channels=in_channels[0], out_channels=out_channels[0], kernel_size=kernel_sizes[0], stride=strides[0])
        self.conv1_BN = batch_norm_fn(num_features=out_channels[0])
        self.conv2 = torch.nn.Conv2d(in_channels=in_channels[1], out_channels=out_channels[1], kernel_size=kernel_sizes[1], stride=strides[1])
        self.conv2_BN = batch_norm_fn(num_features=out_channels[1])

        self.conv3 = torch.nn.Conv2d(in_channels=in_channels[2], out_channels=out_channels[2], kernel_size=kernel_sizes[2], stride=strides[2])
        self.conv3_BN = batch_norm_fn(num_features=out_channels[2])
        self.conv4 = torch.nn.Conv2d(in_channels=in_channels[3], out_channels=out_channels[3], kernel_size=kernel_sizes[3], stride=strides[3])
        self.conv4_BN = batch_norm_fn(num_features=out_channels[3])

        self.conv5 = torch.nn.Conv2d(in_channels=in_channels[4], out_channels=out_channels[4], kernel_size=kernel_sizes[4], stride=strides[4])
        self.conv5_BN = batch_norm_fn(num_features=out_channels[4])
        self.conv6 = torch.nn.Conv2d(in_channels=in_channels[5], out_channels=out_channels[5], kernel_size=kernel_sizes[5], stride=strides[5])
        self.conv6_BN = batch_norm_fn(num_features=out_channels[5])

        self.classifier = torch.nn.Linear(in_features=self.in_features, out_features=PITCH_BINS)

    def forward(self, x, embed=False):
        x = self.embed(x)
        if embed: return x

        return torch.sigmoid(self.classifier(self.layer(x, self.conv6, self.conv6_BN).permute(0, 2, 1, 3).reshape(-1, self.in_features)))

    def embed(self, x):
        x = x[:, None, :, None]

        return self.layer(self.layer(self.layer(self.layer(self.layer(x, self.conv1, self.conv1_BN, (0, 0, 254, 254)), self.conv2, self.conv2_BN), self.conv3, self.conv3_BN), self.conv4, self.conv4_BN), self.conv5, self.conv5_BN)

    def layer(self, x, conv, batch_norm, padding=(0, 0, 31, 32)):
        return torch.nn.functional.max_pool2d(batch_norm(torch.nn.functional.relu(conv(torch.nn.functional.pad(x, padding)))), (2, 1), (2, 1))

def viterbi(logits):
    if not hasattr(viterbi, 'transition'):
        xx, yy = np.meshgrid(range(360), range(360))
        transition = np.maximum(12 - abs(xx - yy), 0)
        viterbi.transition = transition / transition.sum(axis=1, keepdims=True)

    with torch.no_grad():
        probs = torch.nn.functional.softmax(logits, dim=1)

        bins = torch.tensor(np.array([librosa.sequence.viterbi(sequence, viterbi.transition).astype(np.int64) for sequence in probs.cpu().numpy()]), device=probs.device)
        return bins, bins_to_frequency(bins)

def predict(audio, sample_rate, hop_length=None, fmin=50, fmax=MAX_FMAX, model='full', return_periodicity=False, batch_size=None, device='cpu', pad=True, providers=None, onnx=False):
    results = []

    if onnx:
        import onnxruntime as ort

        sess_options = ort.SessionOptions()
        sess_options.log_severity_level = 3

        session = ort.InferenceSession(os.path.join("assets", "models", "predictors", f"crepe_{model}.onnx"), sess_options=sess_options, providers=providers)

        for frames in preprocess(audio, sample_rate, hop_length, batch_size, device, pad):
            result = postprocess(torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: frames.cpu().numpy()})[0].transpose(1, 0)[None]), fmin, fmax, return_periodicity)
            results.append((result[0], result[1]) if isinstance(result, tuple) else result)

        del session

        if return_periodicity:
            pitch, periodicity = zip(*results)
            return torch.cat(pitch, 1), torch.cat(periodicity, 1)

        return torch.cat(results, 1)
    else:
        with torch.no_grad():
            for frames in preprocess(audio, sample_rate, hop_length, batch_size, device, pad):
                result = postprocess(infer(frames, model, device, embed=False).reshape(audio.size(0), -1, PITCH_BINS).transpose(1, 2), fmin, fmax, return_periodicity)
                results.append((result[0].to(audio.device), result[1].to(audio.device)) if isinstance(result, tuple) else result.to(audio.device))

        if return_periodicity:
            pitch, periodicity = zip(*results)
            return torch.cat(pitch, 1), torch.cat(periodicity, 1)

        return torch.cat(results, 1)

def bins_to_frequency(bins):
    cents = CENTS_PER_BIN * bins + 1997.3794084376191
    return 10 * 2 ** ((cents + cents.new_tensor(scipy.stats.triang.rvs(c=0.5, loc=-CENTS_PER_BIN, scale=2 * CENTS_PER_BIN, size=cents.size()))) / 1200)

def frequency_to_bins(frequency, quantize_fn=torch.floor):
    return quantize_fn(((1200 * torch.log2(frequency / 10)) - 1997.3794084376191) / CENTS_PER_BIN).int()

def infer(frames, model='full', device='cpu', embed=False):
    if not hasattr(infer, 'model') or not hasattr(infer, 'capacity') or (hasattr(infer, 'capacity') and infer.capacity != model): load_model(device, model)
    infer.model = infer.model.to(device)

    return infer.model(frames, embed=embed)

def load_model(device, capacity='full'):
    infer.capacity = capacity
    infer.model = Crepe(capacity)
    infer.model.load_state_dict(torch.load(os.path.join("assets", "models", "predictors", f"crepe_{capacity}.pth"), map_location=device))
    infer.model = infer.model.to(torch.device(device))
    infer.model.eval()

def postprocess(probabilities, fmin=0, fmax=MAX_FMAX, return_periodicity=False):
    probabilities = probabilities.detach()

    probabilities[:, :frequency_to_bins(torch.tensor(fmin))] = -float('inf')
    probabilities[:, frequency_to_bins(torch.tensor(fmax), torch.ceil):] = -float('inf')

    bins, pitch = viterbi(probabilities)

    if not return_periodicity: return pitch
    return pitch, periodicity(probabilities, bins)

def preprocess(audio, sample_rate, hop_length=None, batch_size=None, device='cpu', pad=True):
    hop_length = sample_rate // 100 if hop_length is None else hop_length

    if sample_rate != SAMPLE_RATE:
        audio = torch.tensor(librosa.resample(audio.detach().cpu().numpy().squeeze(0), orig_sr=sample_rate, target_sr=SAMPLE_RATE, res_type="soxr_vhq"), device=audio.device).unsqueeze(0)
        hop_length = int(hop_length * SAMPLE_RATE / sample_rate)

    if pad:
        total_frames = 1 + int(audio.size(1) // hop_length)
        audio = torch.nn.functional.pad(audio, (WINDOW_SIZE // 2, WINDOW_SIZE // 2))
    else: total_frames = 1 + int((audio.size(1) - WINDOW_SIZE) // hop_length)

    batch_size = total_frames if batch_size is None else batch_size

    for i in range(0, total_frames, batch_size):
        frames = torch.nn.functional.unfold(audio[:, None, None, max(0, i * hop_length):min(audio.size(1), (i + batch_size - 1) * hop_length + WINDOW_SIZE)], kernel_size=(1, WINDOW_SIZE), stride=(1, hop_length))
        frames = frames.transpose(1, 2).reshape(-1, WINDOW_SIZE).to(device)
        frames -= frames.mean(dim=1, keepdim=True)
        frames /= torch.max(torch.tensor(1e-10, device=frames.device), frames.std(dim=1, keepdim=True))

        yield frames

def periodicity(probabilities, bins):
    probs_stacked = probabilities.transpose(1, 2).reshape(-1, PITCH_BINS)
    periodicity = probs_stacked.gather(1, bins.reshape(-1, 1).to(torch.int64))

    return periodicity.reshape(probabilities.size(0), probabilities.size(2))

def mean(signals, win_length=9):
    assert signals.dim() == 2

    signals = signals.unsqueeze(1)
    mask = ~torch.isnan(signals)
    padding = win_length // 2

    ones_kernel = torch.ones(signals.size(1), 1, win_length, device=signals.device)
    avg_pooled = torch.nn.functional.conv1d(torch.where(mask, signals, torch.zeros_like(signals)), ones_kernel, stride=1, padding=padding) / torch.nn.functional.conv1d(mask.float(), ones_kernel, stride=1, padding=padding).clamp(min=1)
    avg_pooled[avg_pooled == 0] = float("nan")

    return avg_pooled.squeeze(1)

def median(signals, win_length):
    assert signals.dim() == 2

    signals = signals.unsqueeze(1)
    mask = ~torch.isnan(signals)
    padding = win_length // 2

    x = torch.nn.functional.pad(torch.where(mask, signals, torch.zeros_like(signals)), (padding, padding), mode="reflect")
    mask = torch.nn.functional.pad(mask.float(), (padding, padding), mode="constant", value=0)

    x = x.unfold(2, win_length, 1)
    mask = mask.unfold(2, win_length, 1)

    x = x.contiguous().view(x.size()[:3] + (-1,))
    mask = mask.contiguous().view(mask.size()[:3] + (-1,))

    x_sorted, _ = torch.sort(torch.where(mask.bool(), x.float(), float("inf")).to(x), dim=-1)

    median_pooled = x_sorted.gather(-1, ((mask.sum(dim=-1) - 1) // 2).clamp(min=0).unsqueeze(-1).long()).squeeze(-1)
    median_pooled[torch.isinf(median_pooled)] = float("nan")

    return median_pooled.squeeze(1)
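A minimal sketch of calling the predict function above; the audio path and hop length are illustrative, and the weights are expected where load_model looks for them (assets/models/predictors/crepe_full.pth).

import torch
import librosa

audio, sr = librosa.load("voice.wav", sr=16000)           # "voice.wav" is a placeholder
audio = torch.from_numpy(audio).unsqueeze(0)               # predict expects a (1, samples) tensor
pitch, periodicity = predict(audio, sr, hop_length=160, fmin=50, fmax=1100, model="full", return_periodicity=True, device="cpu")
print(pitch.shape, periodicity.shape)                      # (1, frames) each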
main/library/predictors/FCPE.py
DELETED
@@ -1,1097 +0,0 @@
import os
import io
import math
import torch
import librosa

import numpy as np
import soundfile as sf
import onnxruntime as ort
import torch.nn.functional as F

from torch import nn, einsum
from functools import partial
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad
from torchaudio.transforms import Resample
from einops import rearrange, repeat, pack, unpack
from torch.nn.utils.parametrizations import weight_norm

from librosa.filters import mel as librosa_mel_fn

os.environ["LRU_CACHE_CAPACITY"] = "3"

def exists(val):
    return val is not None

def default(value, d):
    return value if exists(value) else d

def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max

def empty(tensor):
    return tensor.numel() == 0

def l2norm(tensor):
    return F.normalize(tensor, dim = -1).type(tensor.dtype)

def decrypt_model(input_path):
    with open(input_path, "rb") as f:
        data = f.read()

    with open(os.path.join("main", "configs", "decrypt.bin"), "rb") as f:
        key = f.read()

    return io.BytesIO(unpad(AES.new(key, AES.MODE_CBC, data[:16]).decrypt(data[16:]), AES.block_size)).read()

def l2_regularization(model, l2_alpha):
    l2_loss = []

    for module in model.modules():
        if type(module) is nn.Conv2d: l2_loss.append((module.weight**2).sum() / 2.0)

    return l2_alpha * sum(l2_loss)

def pad_to_multiple(tensor, multiple, dim=-1, value=0):
    seqlen = tensor.shape[dim]
    m = seqlen / multiple

    if m.is_integer(): return False, tensor
    return True, F.pad(tensor, (*((0,) * (-1 - dim) * 2), 0, (math.ceil(m) * multiple - seqlen)), value = value)

def look_around(x, backward = 1, forward = 0, pad_value = -1, dim = 2):
    t = x.shape[1]
    dims = (len(x.shape) - dim) * (0, 0)

    padded_x = F.pad(x, (*dims, backward, forward), value = pad_value)
    return torch.cat([padded_x[:, ind:(ind + t), ...] for ind in range(forward + backward + 1)], dim = dim)

def rotate_half(x):
    x1, x2 = rearrange(x, 'b ... (r d) -> b ... r d', r = 2).unbind(dim = -2)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary_pos_emb(q, k, freqs, scale = 1):
    q_len = q.shape[-2]
    q_freqs = freqs[..., -q_len:, :]
    inv_scale = scale ** -1

    if scale.ndim == 2: scale = scale[-q_len:, :]

    q = (q * q_freqs.cos() * scale) + (rotate_half(q) * q_freqs.sin() * scale)
    k = (k * freqs.cos() * inv_scale) + (rotate_half(k) * freqs.sin() * inv_scale)

    return q, k

def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)

def orthogonal_matrix_chunk(cols, qr_uniform_q=False, device=None):
    unstructured_block = torch.randn((cols, cols), device=device)

    q, r = torch.linalg.qr(unstructured_block.cpu(), mode="reduced")
    q, r = map(lambda t: t.to(device), (q, r))

    if qr_uniform_q:
        d = torch.diag(r, 0)
        q *= d.sign()

    return q.t()
return q.t()
|
100 |
-
|
101 |
-
def linear_attention(q, k, v):
|
102 |
-
return einsum("...ed,...nd->...ne", k, q) if v is None else einsum("...de,...nd,...n->...ne", einsum("...nd,...ne->...de", k, v), q, 1.0 / (einsum("...nd,...d->...n", q, k.sum(dim=-2).type_as(q)) + 1e-8))
|
103 |
-
|
104 |
-
def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling=0, qr_uniform_q=False, device=None):
|
105 |
-
nb_full_blocks = int(nb_rows / nb_columns)
|
106 |
-
block_list = []
|
107 |
-
|
108 |
-
for _ in range(nb_full_blocks):
|
109 |
-
block_list.append(orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device))
|
110 |
-
|
111 |
-
remaining_rows = nb_rows - nb_full_blocks * nb_columns
|
112 |
-
if remaining_rows > 0: block_list.append(orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device)[:remaining_rows])
|
113 |
-
|
114 |
-
if scaling == 0: multiplier = torch.randn((nb_rows, nb_columns), device=device).norm(dim=1)
|
115 |
-
elif scaling == 1: multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device=device)
|
116 |
-
else: raise ValueError(f"{scaling} != 0, 1")
|
117 |
-
|
118 |
-
return torch.diag(multiplier) @ torch.cat(block_list)
|
119 |
-
|
120 |
-
def calc_same_padding(kernel_size):
|
121 |
-
pad = kernel_size // 2
|
122 |
-
return (pad, pad - (kernel_size + 1) % 2)
|
123 |
-
|
124 |
-
def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device=None):
|
125 |
-
b, h, *_ = data.shape
|
126 |
-
|
127 |
-
data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.0
|
128 |
-
ratio = projection_matrix.shape[0] ** -0.5
|
129 |
-
|
130 |
-
data_dash = torch.einsum("...id,...jd->...ij", (data_normalizer * data), repeat(projection_matrix, "j d -> b h j d", b=b, h=h).type_as(data))
|
131 |
-
diag_data = ((torch.sum(data**2, dim=-1) / 2.0) * (data_normalizer**2)).unsqueeze(dim=-1)
|
132 |
-
|
133 |
-
return (ratio * (torch.exp(data_dash - diag_data - torch.max(data_dash, dim=-1, keepdim=True).values) + eps) if is_query else ratio * (torch.exp(data_dash - diag_data + eps))).type_as(data)
|
134 |
-
|
135 |
-
def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
|
136 |
-
try:
|
137 |
-
data, sample_rate = sf.read(full_path, always_2d=True)
|
138 |
-
except Exception as e:
|
139 |
-
print(f"{full_path}: {e}")
|
140 |
-
|
141 |
-
if return_empty_on_exception: return [], sample_rate or target_sr or 48000
|
142 |
-
else: raise
|
143 |
-
|
144 |
-
data = data[:, 0] if len(data.shape) > 1 else data
|
145 |
-
assert len(data) > 2
|
146 |
-
|
147 |
-
max_mag = (-np.iinfo(data.dtype).min if np.issubdtype(data.dtype, np.integer) else max(np.amax(data), -np.amin(data)))
|
148 |
-
data = torch.FloatTensor(data.astype(np.float32)) / ((2**31) + 1 if max_mag > (2**15) else ((2**15) + 1 if max_mag > 1.01 else 1.0))
|
149 |
-
|
150 |
-
if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception: return [], sample_rate or target_sr or 48000
|
151 |
-
|
152 |
-
if target_sr is not None and sample_rate != target_sr:
|
153 |
-
data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sample_rate, target_sr=target_sr))
|
154 |
-
sample_rate = target_sr
|
155 |
-
|
156 |
-
return data, sample_rate
|
157 |
-
|
158 |
-
def torch_interp(x, xp, fp):
|
159 |
-
sort_idx = torch.argsort(xp)
|
160 |
-
|
161 |
-
xp = xp[sort_idx]
|
162 |
-
fp = fp[sort_idx]
|
163 |
-
|
164 |
-
right_idxs = torch.searchsorted(xp, x).clamp(max=len(xp) - 1)
|
165 |
-
left_idxs = (right_idxs - 1).clamp(min=0)
|
166 |
-
|
167 |
-
x_left = xp[left_idxs]
|
168 |
-
y_left = fp[left_idxs]
|
169 |
-
|
170 |
-
interp_vals = y_left + ((x - x_left) * (fp[right_idxs] - y_left) / (xp[right_idxs] - x_left))
|
171 |
-
interp_vals[x < xp[0]] = fp[0]
|
172 |
-
interp_vals[x > xp[-1]] = fp[-1]
|
173 |
-
|
174 |
-
return interp_vals
|
175 |
-
|
176 |
-
def batch_interp_with_replacement_detach(uv, f0):
|
177 |
-
result = f0.clone()
|
178 |
-
|
179 |
-
for i in range(uv.shape[0]):
|
180 |
-
interp_vals = torch_interp(torch.where(uv[i])[-1], torch.where(~uv[i])[-1], f0[i][~uv[i]]).detach()
|
181 |
-
result[i][uv[i]] = interp_vals
|
182 |
-
|
183 |
-
return result
|
184 |
-
|
185 |
-
def spawn_model(args):
|
186 |
-
return CFNaiveMelPE(input_channels=catch_none_args_must(args.mel.num_mels, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.mel.num_mels is None"), out_dims=catch_none_args_must(args.model.out_dims, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.out_dims is None"), hidden_dims=catch_none_args_must(args.model.hidden_dims, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.hidden_dims is None"), n_layers=catch_none_args_must(args.model.n_layers, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.n_layers is None"), n_heads=catch_none_args_must(args.model.n_heads, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.n_heads is None"), f0_max=catch_none_args_must(args.model.f0_max, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.f0_max is None"), f0_min=catch_none_args_must(args.model.f0_min, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.f0_min is None"), use_fa_norm=catch_none_args_must(args.model.use_fa_norm, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.use_fa_norm is None"), conv_only=catch_none_args_opti(args.model.conv_only, default=False, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.conv_only is None"), conv_dropout=catch_none_args_opti(args.model.conv_dropout, default=0.0, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.conv_dropout is None"), atten_dropout=catch_none_args_opti(args.model.atten_dropout, default=0.0, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.atten_dropout is None"), use_harmonic_emb=catch_none_args_opti(args.model.use_harmonic_emb, default=False, func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", warning_str="args.model.use_harmonic_emb is None"))
|
187 |
-
|
188 |
-
def catch_none_args_must(x, func_name, warning_str):
|
189 |
-
level = "ERROR"
|
190 |
-
|
191 |
-
if x is None:
|
192 |
-
print(f' [{level}] {warning_str}')
|
193 |
-
print(f' [{level}] > {func_name}')
|
194 |
-
raise ValueError(f' [{level}] {warning_str}')
|
195 |
-
else: return x
|
196 |
-
|
197 |
-
def catch_none_args_opti(x, default, func_name, warning_str=None, level='WARN'):
|
198 |
-
return default if x is None else x
|
199 |
-
|
200 |
-
def spawn_wav2mel(args, device = None):
|
201 |
-
_type = args.mel.type
|
202 |
-
|
203 |
-
if (str(_type).lower() == 'none') or (str(_type).lower() == 'default'): _type = 'default'
|
204 |
-
elif str(_type).lower() == 'stft': _type = 'stft'
|
205 |
-
|
206 |
-
wav2mel = Wav2MelModule(sr=catch_none_args_opti(args.mel.sr, default=16000, func_name='torchfcpe.tools.spawn_wav2mel', warning_str='args.mel.sr is None'), n_mels=catch_none_args_opti(args.mel.num_mels, default=128, func_name='torchfcpe.tools.spawn_wav2mel', warning_str='args.mel.num_mels is None'), n_fft=catch_none_args_opti(args.mel.n_fft, default=1024, func_name='torchfcpe.tools.spawn_wav2mel', warning_str='args.mel.n_fft is None'), win_size=catch_none_args_opti(args.mel.win_size, default=1024, func_name='torchfcpe.tools.spawn_wav2mel', warning_str='args.mel.win_size is None'), hop_length=catch_none_args_opti(args.mel.hop_size, default=160, func_name='torchfcpe.tools.spawn_wav2mel', warning_str='args.mel.hop_size is None'), fmin=catch_none_args_opti(args.mel.fmin, default=0, func_name='torchfcpe.tools.spawn_wav2mel', warning_str='args.mel.fmin is None'), fmax=catch_none_args_opti(args.mel.fmax, default=8000, func_name='torchfcpe.tools.spawn_wav2mel', warning_str='args.mel.fmax is None'), clip_val=1e-05, mel_type=_type)
|
207 |
-
device = catch_none_args_opti(device, default='cpu', func_name='torchfcpe.tools.spawn_wav2mel', warning_str='.device is None')
|
208 |
-
|
209 |
-
return wav2mel.to(torch.device(device))
|
210 |
-
|
211 |
-
def ensemble_f0(f0s, key_shift_list, tta_uv_penalty):
|
212 |
-
device = f0s.device
|
213 |
-
f0s = f0s / (torch.pow(2, torch.tensor(key_shift_list, device=device).to(device).unsqueeze(0).unsqueeze(0) / 12))
|
214 |
-
|
215 |
-
notes = torch.log2(f0s / 440) * 12 + 69
|
216 |
-
notes[notes < 0] = 0
|
217 |
-
|
218 |
-
uv_penalty = tta_uv_penalty**2
|
219 |
-
dp = torch.zeros_like(notes, device=device)
|
220 |
-
|
221 |
-
backtrack = torch.zeros_like(notes, device=device).long()
|
222 |
-
dp[:, 0, :] = (notes[:, 0, :] <= 0) * uv_penalty
|
223 |
-
|
224 |
-
for t in range(1, notes.size(1)):
|
225 |
-
penalty = torch.zeros([notes.size(0), notes.size(2), notes.size(2)], device=device)
|
226 |
-
t_uv = notes[:, t, :] <= 0
|
227 |
-
penalty += uv_penalty * t_uv.unsqueeze(1)
|
228 |
-
|
229 |
-
t1_uv = notes[:, t - 1, :] <= 0
|
230 |
-
l2 = torch.pow((notes[:, t - 1, :].unsqueeze(-1) - notes[:, t, :].unsqueeze(1)) * (~t1_uv).unsqueeze(-1) * (~t_uv).unsqueeze(1), 2) - 0.5
|
231 |
-
l2 = l2 * (l2 > 0)
|
232 |
-
|
233 |
-
penalty += l2
|
234 |
-
penalty += t1_uv.unsqueeze(-1) * (~t_uv).unsqueeze(1) * uv_penalty * 2
|
235 |
-
|
236 |
-
min_value, min_indices = torch.min(dp[:, t - 1, :].unsqueeze(-1) + penalty, dim=1)
|
237 |
-
dp[:, t, :] = min_value
|
238 |
-
backtrack[:, t, :] = min_indices
|
239 |
-
|
240 |
-
t = f0s.size(1) - 1
|
241 |
-
f0_result = torch.zeros_like(f0s[:, :, 0], device=device)
|
242 |
-
min_indices = torch.argmin(dp[:, t, :], dim=-1)
|
243 |
-
|
244 |
-
for i in range(0, t + 1):
|
245 |
-
f0_result[:, t - i] = f0s[:, t - i, min_indices]
|
246 |
-
min_indices = backtrack[:, t - i, min_indices]
|
247 |
-
|
248 |
-
return f0_result.unsqueeze(-1)
|
249 |
-
|
250 |
-
class LocalAttention(nn.Module):
|
251 |
-
def __init__(self, window_size, causal = False, look_backward = 1, look_forward = None, dropout = 0., shared_qk = False, rel_pos_emb_config = None, dim = None, autopad = False, exact_windowsize = False, scale = None, use_rotary_pos_emb = True, use_xpos = False, xpos_scale_base = None):
|
252 |
-
super().__init__()
|
253 |
-
look_forward = default(look_forward, 0 if causal else 1)
|
254 |
-
assert not (causal and look_forward > 0)
|
255 |
-
self.scale = scale
|
256 |
-
self.window_size = window_size
|
257 |
-
self.autopad = autopad
|
258 |
-
self.exact_windowsize = exact_windowsize
|
259 |
-
self.causal = causal
|
260 |
-
self.look_backward = look_backward
|
261 |
-
self.look_forward = look_forward
|
262 |
-
self.dropout = nn.Dropout(dropout)
|
263 |
-
self.shared_qk = shared_qk
|
264 |
-
self.rel_pos = None
|
265 |
-
self.use_xpos = use_xpos
|
266 |
-
|
267 |
-
if use_rotary_pos_emb and (exists(rel_pos_emb_config) or exists(dim)):
|
268 |
-
if exists(rel_pos_emb_config): dim = rel_pos_emb_config[0]
|
269 |
-
self.rel_pos = SinusoidalEmbeddings(dim, use_xpos = use_xpos, scale_base = default(xpos_scale_base, window_size // 2))
|
270 |
-
|
271 |
-
def forward(self, q, k, v, mask = None, input_mask = None, attn_bias = None, window_size = None):
|
272 |
-
mask = default(mask, input_mask)
|
273 |
-
assert not (exists(window_size) and not self.use_xpos)
|
274 |
-
|
275 |
-
_, autopad, pad_value, window_size, causal, look_backward, look_forward, shared_qk = q.shape, self.autopad, -1, default(window_size, self.window_size), self.causal, self.look_backward, self.look_forward, self.shared_qk
|
276 |
-
(q, packed_shape), (k, _), (v, _) = map(lambda t: pack([t], '* n d'), (q, k, v))
|
277 |
-
|
278 |
-
if autopad:
|
279 |
-
orig_seq_len = q.shape[1]
|
280 |
-
(_, q), (_, k), (_, v) = map(lambda t: pad_to_multiple(t, self.window_size, dim = -2), (q, k, v))
|
281 |
-
|
282 |
-
b, n, dim_head, device, dtype = *q.shape, q.device, q.dtype
|
283 |
-
scale = default(self.scale, dim_head ** -0.5)
|
284 |
-
|
285 |
-
assert (n % window_size) == 0
|
286 |
-
windows = n // window_size
|
287 |
-
|
288 |
-
if shared_qk: k = l2norm(k)
|
289 |
-
|
290 |
-
seq = torch.arange(n, device = device)
|
291 |
-
b_t = rearrange(seq, '(w n) -> 1 w n', w = windows, n = window_size)
|
292 |
-
bq, bk, bv = map(lambda t: rearrange(t, 'b (w n) d -> b w n d', w = windows), (q, k, v))
|
293 |
-
|
294 |
-
bq = bq * scale
|
295 |
-
look_around_kwargs = dict(backward = look_backward, forward = look_forward, pad_value = pad_value)
|
296 |
-
|
297 |
-
bk = look_around(bk, **look_around_kwargs)
|
298 |
-
bv = look_around(bv, **look_around_kwargs)
|
299 |
-
|
300 |
-
if exists(self.rel_pos):
|
301 |
-
pos_emb, xpos_scale = self.rel_pos(bk)
|
302 |
-
bq, bk = apply_rotary_pos_emb(bq, bk, pos_emb, scale = xpos_scale)
|
303 |
-
|
304 |
-
bq_t = b_t
|
305 |
-
bq_k = look_around(b_t, **look_around_kwargs)
|
306 |
-
|
307 |
-
bq_t = rearrange(bq_t, '... i -> ... i 1')
|
308 |
-
bq_k = rearrange(bq_k, '... j -> ... 1 j')
|
309 |
-
|
310 |
-
pad_mask = bq_k == pad_value
|
311 |
-
sim = einsum('b h i e, b h j e -> b h i j', bq, bk)
|
312 |
-
|
313 |
-
if exists(attn_bias):
|
314 |
-
heads = attn_bias.shape[0]
|
315 |
-
assert (b % heads) == 0
|
316 |
-
|
317 |
-
attn_bias = repeat(attn_bias, 'h i j -> (b h) 1 i j', b = b // heads)
|
318 |
-
sim = sim + attn_bias
|
319 |
-
|
320 |
-
mask_value = max_neg_value(sim)
|
321 |
-
|
322 |
-
if shared_qk:
|
323 |
-
self_mask = bq_t == bq_k
|
324 |
-
sim = sim.masked_fill(self_mask, -5e4)
|
325 |
-
del self_mask
|
326 |
-
|
327 |
-
if causal:
|
328 |
-
causal_mask = bq_t < bq_k
|
329 |
-
if self.exact_windowsize: causal_mask = causal_mask | (bq_t > (bq_k + (self.window_size * self.look_backward)))
|
330 |
-
sim = sim.masked_fill(causal_mask, mask_value)
|
331 |
-
del causal_mask
|
332 |
-
|
333 |
-
sim = sim.masked_fill(((bq_k - (self.window_size * self.look_forward)) > bq_t) | (bq_t > (bq_k + (self.window_size * self.look_backward))) | pad_mask, mask_value) if not causal and self.exact_windowsize else sim.masked_fill(pad_mask, mask_value)
|
334 |
-
|
335 |
-
if exists(mask):
|
336 |
-
batch = mask.shape[0]
|
337 |
-
assert (b % batch) == 0
|
338 |
-
|
339 |
-
h = b // mask.shape[0]
|
340 |
-
if autopad: _, mask = pad_to_multiple(mask, window_size, dim = -1, value = False)
|
341 |
-
|
342 |
-
mask = repeat(rearrange(look_around(rearrange(mask, '... (w n) -> (...) w n', w = windows, n = window_size), **{**look_around_kwargs, 'pad_value': False}), '... j -> ... 1 j'), 'b ... -> (b h) ...', h = h)
|
343 |
-
sim = sim.masked_fill(~mask, mask_value)
|
344 |
-
|
345 |
-
del mask
|
346 |
-
|
347 |
-
out = rearrange(einsum('b h i j, b h j e -> b h i e', self.dropout(sim.softmax(dim = -1)), bv), 'b w n d -> b (w n) d')
|
348 |
-
if autopad: out = out[:, :orig_seq_len, :]
|
349 |
-
|
350 |
-
out, *_ = unpack(out, packed_shape, '* n d')
|
351 |
-
return out
|
352 |
-
|
353 |
-
class SinusoidalEmbeddings(nn.Module):
|
354 |
-
def __init__(self, dim, scale_base = None, use_xpos = False, theta = 10000):
|
355 |
-
super().__init__()
|
356 |
-
inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
|
357 |
-
self.register_buffer('inv_freq', inv_freq)
|
358 |
-
self.use_xpos = use_xpos
|
359 |
-
self.scale_base = scale_base
|
360 |
-
assert not (use_xpos and not exists(scale_base))
|
361 |
-
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
362 |
-
self.register_buffer('scale', scale, persistent = False)
|
363 |
-
|
364 |
-
def forward(self, x):
|
365 |
-
seq_len, device = x.shape[-2], x.device
|
366 |
-
t = torch.arange(seq_len, device = x.device).type_as(self.inv_freq)
|
367 |
-
|
368 |
-
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
|
369 |
-
freqs = torch.cat((freqs, freqs), dim = -1)
|
370 |
-
|
371 |
-
if not self.use_xpos: return freqs, torch.ones(1, device = device)
|
372 |
-
|
373 |
-
power = (t - (seq_len // 2)) / self.scale_base
|
374 |
-
scale = self.scale ** rearrange(power, 'n -> n 1')
|
375 |
-
|
376 |
-
return freqs, torch.cat((scale, scale), dim = -1)
|
377 |
-
|
378 |
-
class STFT:
|
379 |
-
def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
|
380 |
-
self.target_sr = sr
|
381 |
-
self.n_mels = n_mels
|
382 |
-
self.n_fft = n_fft
|
383 |
-
self.win_size = win_size
|
384 |
-
self.hop_length = hop_length
|
385 |
-
self.fmin = fmin
|
386 |
-
self.fmax = fmax
|
387 |
-
self.clip_val = clip_val
|
388 |
-
self.mel_basis = {}
|
389 |
-
self.hann_window = {}
|
390 |
-
|
391 |
-
def get_mel(self, y, keyshift=0, speed=1, center=False, train=False):
|
392 |
-
n_fft = self.n_fft
|
393 |
-
win_size = self.win_size
|
394 |
-
hop_length = self.hop_length
|
395 |
-
fmax = self.fmax
|
396 |
-
factor = 2 ** (keyshift / 12)
|
397 |
-
win_size_new = int(np.round(win_size * factor))
|
398 |
-
hop_length_new = int(np.round(hop_length * speed))
|
399 |
-
mel_basis = self.mel_basis if not train else {}
|
400 |
-
hann_window = self.hann_window if not train else {}
|
401 |
-
mel_basis_key = str(fmax) + "_" + str(y.device)
|
402 |
-
|
403 |
-
if mel_basis_key not in mel_basis: mel_basis[mel_basis_key] = torch.from_numpy(librosa_mel_fn(sr=self.target_sr, n_fft=n_fft, n_mels=self.n_mels, fmin=self.fmin, fmax=fmax)).float().to(y.device)
|
404 |
-
keyshift_key = str(keyshift) + "_" + str(y.device)
|
405 |
-
if keyshift_key not in hann_window: hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)
|
406 |
-
|
407 |
-
pad_left = (win_size_new - hop_length_new) // 2
|
408 |
-
pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)
|
409 |
-
|
410 |
-
spec = torch.stft(F.pad(y.unsqueeze(1), (pad_left, pad_right), mode="reflect" if pad_right < y.size(-1) else "constant").squeeze(1), int(np.round(n_fft * factor)), hop_length=hop_length_new, win_length=win_size_new, window=hann_window[keyshift_key], center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
|
411 |
-
spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9))
|
412 |
-
|
413 |
-
if keyshift != 0:
|
414 |
-
size = n_fft // 2 + 1
|
415 |
-
resize = spec.size(1)
|
416 |
-
spec = (F.pad(spec, (0, 0, 0, size - resize)) if resize < size else spec[:, :size, :]) * win_size / win_size_new
|
417 |
-
|
418 |
-
return dynamic_range_compression_torch(torch.matmul(mel_basis[mel_basis_key], spec), clip_val=self.clip_val)
|
419 |
-
|
420 |
-
def __call__(self, audiopath):
|
421 |
-
audio, _ = load_wav_to_torch(audiopath, target_sr=self.target_sr)
|
422 |
-
return self.get_mel(audio.unsqueeze(0)).squeeze(0)
|
423 |
-
|
424 |
-
class PCmer(nn.Module):
|
425 |
-
def __init__(self, num_layers, num_heads, dim_model, dim_keys, dim_values, residual_dropout, attention_dropout):
|
426 |
-
super().__init__()
|
427 |
-
self.num_layers = num_layers
|
428 |
-
self.num_heads = num_heads
|
429 |
-
self.dim_model = dim_model
|
430 |
-
self.dim_values = dim_values
|
431 |
-
self.dim_keys = dim_keys
|
432 |
-
self.residual_dropout = residual_dropout
|
433 |
-
self.attention_dropout = attention_dropout
|
434 |
-
self._layers = nn.ModuleList([_EncoderLayer(self) for _ in range(num_layers)])
|
435 |
-
|
436 |
-
def forward(self, phone, mask=None):
|
437 |
-
for layer in self._layers:
|
438 |
-
phone = layer(phone, mask)
|
439 |
-
|
440 |
-
return phone
|
441 |
-
|
442 |
-
class _EncoderLayer(nn.Module):
|
443 |
-
def __init__(self, parent):
|
444 |
-
super().__init__()
|
445 |
-
self.conformer = ConformerConvModule_LEGACY(parent.dim_model)
|
446 |
-
self.norm = nn.LayerNorm(parent.dim_model)
|
447 |
-
self.dropout = nn.Dropout(parent.residual_dropout)
|
448 |
-
self.attn = SelfAttention(dim=parent.dim_model, heads=parent.num_heads, causal=False)
|
449 |
-
|
450 |
-
def forward(self, phone, mask=None):
|
451 |
-
phone = phone + (self.attn(self.norm(phone), mask=mask))
|
452 |
-
return phone + (self.conformer(phone))
|
453 |
-
|
454 |
-
class ConformerNaiveEncoder(nn.Module):
|
455 |
-
def __init__(self, num_layers, num_heads, dim_model, use_norm = False, conv_only = False, conv_dropout = 0, atten_dropout = 0):
|
456 |
-
super().__init__()
|
457 |
-
self.num_layers = num_layers
|
458 |
-
self.num_heads = num_heads
|
459 |
-
self.dim_model = dim_model
|
460 |
-
self.use_norm = use_norm
|
461 |
-
self.residual_dropout = 0.1
|
462 |
-
self.attention_dropout = 0.1
|
463 |
-
self.encoder_layers = nn.ModuleList([CFNEncoderLayer(dim_model, num_heads, use_norm, conv_only, conv_dropout, atten_dropout) for _ in range(num_layers)])
|
464 |
-
|
465 |
-
def forward(self, x, mask=None):
|
466 |
-
for (_, layer) in enumerate(self.encoder_layers):
|
467 |
-
x = layer(x, mask)
|
468 |
-
|
469 |
-
return x
|
470 |
-
|
471 |
-
class CFNaiveMelPE(nn.Module):
|
472 |
-
def __init__(self, input_channels, out_dims, hidden_dims = 512, n_layers = 6, n_heads = 8, f0_max = 1975.5, f0_min = 32.70, use_fa_norm = False, conv_only = False, conv_dropout = 0, atten_dropout = 0, use_harmonic_emb = False):
|
473 |
-
super().__init__()
|
474 |
-
self.input_channels = input_channels
|
475 |
-
self.out_dims = out_dims
|
476 |
-
self.hidden_dims = hidden_dims
|
477 |
-
self.n_layers = n_layers
|
478 |
-
self.n_heads = n_heads
|
479 |
-
self.f0_max = f0_max
|
480 |
-
self.f0_min = f0_min
|
481 |
-
self.use_fa_norm = use_fa_norm
|
482 |
-
self.residual_dropout = 0.1
|
483 |
-
self.attention_dropout = 0.1
|
484 |
-
self.harmonic_emb = nn.Embedding(9, hidden_dims) if use_harmonic_emb else None
|
485 |
-
self.input_stack = nn.Sequential(nn.Conv1d(input_channels, hidden_dims, 3, 1, 1), nn.GroupNorm(4, hidden_dims), nn.LeakyReLU(), nn.Conv1d(hidden_dims, hidden_dims, 3, 1, 1))
|
486 |
-
self.net = ConformerNaiveEncoder(num_layers=n_layers, num_heads=n_heads, dim_model=hidden_dims, use_norm=use_fa_norm, conv_only=conv_only, conv_dropout=conv_dropout, atten_dropout=atten_dropout)
|
487 |
-
self.norm = nn.LayerNorm(hidden_dims)
|
488 |
-
self.output_proj = weight_norm(nn.Linear(hidden_dims, out_dims))
|
489 |
-
self.cent_table_b = torch.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], out_dims).detach()
|
490 |
-
self.register_buffer("cent_table", self.cent_table_b)
|
491 |
-
self.gaussian_blurred_cent_mask_b = (1200 * torch.log2(torch.Tensor([self.f0_max / 10.])))[0].detach()
|
492 |
-
self.register_buffer("gaussian_blurred_cent_mask", self.gaussian_blurred_cent_mask_b)
|
493 |
-
|
494 |
-
def forward(self, x, _h_emb=None):
|
495 |
-
x = self.input_stack(x.transpose(-1, -2)).transpose(-1, -2)
|
496 |
-
if self.harmonic_emb is not None: x = x + self.harmonic_emb(torch.LongTensor([0]).to(x.device)) if _h_emb is None else x + self.harmonic_emb(torch.LongTensor([int(_h_emb)]).to(x.device))
|
497 |
-
|
498 |
-
return torch.sigmoid(self.output_proj(self.norm(self.net(x))))
|
499 |
-
|
500 |
-
@torch.no_grad()
|
501 |
-
def latent2cents_decoder(self, y, threshold = 0.05, mask = True):
|
502 |
-
B, N, _ = y.size()
|
503 |
-
ci = self.cent_table[None, None, :].expand(B, N, -1)
|
504 |
-
rtn = torch.sum(ci * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True)
|
505 |
-
|
506 |
-
if mask:
|
507 |
-
confident = torch.max(y, dim=-1, keepdim=True)[0]
|
508 |
-
confident_mask = torch.ones_like(confident)
|
509 |
-
confident_mask[confident <= threshold] = float("-INF")
|
510 |
-
rtn = rtn * confident_mask
|
511 |
-
|
512 |
-
return rtn
|
513 |
-
|
514 |
-
@torch.no_grad()
|
515 |
-
def latent2cents_local_decoder(self, y, threshold = 0.05, mask = True):
|
516 |
-
B, N, _ = y.size()
|
517 |
-
ci = self.cent_table[None, None, :].expand(B, N, -1)
|
518 |
-
confident, max_index = torch.max(y, dim=-1, keepdim=True)
|
519 |
-
|
520 |
-
local_argmax_index = torch.arange(0, 9).to(max_index.device) + (max_index - 4)
|
521 |
-
local_argmax_index[local_argmax_index < 0] = 0
|
522 |
-
local_argmax_index[local_argmax_index >= self.out_dims] = self.out_dims - 1
|
523 |
-
|
524 |
-
y_l = torch.gather(y, -1, local_argmax_index)
|
525 |
-
rtn = torch.sum(torch.gather(ci, -1, local_argmax_index) * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True)
|
526 |
-
|
527 |
-
if mask:
|
528 |
-
confident_mask = torch.ones_like(confident)
|
529 |
-
confident_mask[confident <= threshold] = float("-INF")
|
530 |
-
|
531 |
-
rtn = rtn * confident_mask
|
532 |
-
|
533 |
-
return rtn
|
534 |
-
|
535 |
-
@torch.no_grad()
|
536 |
-
def infer(self, mel, decoder = "local_argmax", threshold = 0.05):
|
537 |
-
latent = self.forward(mel)
|
538 |
-
|
539 |
-
if decoder == "argmax": cents = self.latent2cents_local_decoder
|
540 |
-
elif decoder == "local_argmax": cents = self.latent2cents_local_decoder
|
541 |
-
|
542 |
-
return self.cent_to_f0(cents(latent, threshold=threshold))
|
543 |
-
|
544 |
-
@torch.no_grad()
|
545 |
-
def cent_to_f0(self, cent: torch.Tensor) -> torch.Tensor:
|
546 |
-
return 10 * 2 ** (cent / 1200)
|
547 |
-
|
548 |
-
@torch.no_grad()
|
549 |
-
def f0_to_cent(self, f0):
|
550 |
-
return 1200 * torch.log2(f0 / 10)
|
551 |
-
|
552 |
-
class CFNEncoderLayer(nn.Module):
|
553 |
-
def __init__(self, dim_model, num_heads = 8, use_norm = False, conv_only = False, conv_dropout = 0, atten_dropout = 0):
|
554 |
-
super().__init__()
|
555 |
-
|
556 |
-
self.conformer = nn.Sequential(ConformerConvModule(dim_model), nn.Dropout(conv_dropout)) if conv_dropout > 0 else ConformerConvModule(dim_model)
|
557 |
-
self.norm = nn.LayerNorm(dim_model)
|
558 |
-
|
559 |
-
self.dropout = nn.Dropout(0.1)
|
560 |
-
self.attn = SelfAttention(dim=dim_model, heads=num_heads, causal=False, use_norm=use_norm, dropout=atten_dropout) if not conv_only else None
|
561 |
-
|
562 |
-
def forward(self, x, mask=None):
|
563 |
-
if self.attn is not None: x = x + (self.attn(self.norm(x), mask=mask))
|
564 |
-
return x + (self.conformer(x))
|
565 |
-
|
566 |
-
class Swish(nn.Module):
|
567 |
-
def forward(self, x):
|
568 |
-
return x * x.sigmoid()
|
569 |
-
|
570 |
-
class Transpose(nn.Module):
|
571 |
-
def __init__(self, dims):
|
572 |
-
super().__init__()
|
573 |
-
assert len(dims) == 2, "dims == 2"
|
574 |
-
|
575 |
-
self.dims = dims
|
576 |
-
|
577 |
-
def forward(self, x):
|
578 |
-
return x.transpose(*self.dims)
|
579 |
-
|
580 |
-
class GLU(nn.Module):
|
581 |
-
def __init__(self, dim):
|
582 |
-
super().__init__()
|
583 |
-
self.dim = dim
|
584 |
-
|
585 |
-
def forward(self, x):
|
586 |
-
out, gate = x.chunk(2, dim=self.dim)
|
587 |
-
return out * gate.sigmoid()
|
588 |
-
|
589 |
-
class DepthWiseConv1d_LEGACY(nn.Module):
|
590 |
-
def __init__(self, chan_in, chan_out, kernel_size, padding):
|
591 |
-
super().__init__()
|
592 |
-
self.padding = padding
|
593 |
-
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)
|
594 |
-
|
595 |
-
def forward(self, x):
|
596 |
-
return self.conv(F.pad(x, self.padding))
|
597 |
-
|
598 |
-
class DepthWiseConv1d(nn.Module):
|
599 |
-
def __init__(self, chan_in, chan_out, kernel_size, padding, groups):
|
600 |
-
super().__init__()
|
601 |
-
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=kernel_size, padding=padding, groups=groups)
|
602 |
-
|
603 |
-
def forward(self, x):
|
604 |
-
return self.conv(x)
|
605 |
-
|
606 |
-
class ConformerConvModule_LEGACY(nn.Module):
|
607 |
-
def __init__(self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0):
|
608 |
-
super().__init__()
|
609 |
-
inner_dim = dim * expansion_factor
|
610 |
-
self.net = nn.Sequential(nn.LayerNorm(dim), Transpose((1, 2)), nn.Conv1d(dim, inner_dim * 2, 1), GLU(dim=1), DepthWiseConv1d_LEGACY(inner_dim, inner_dim, kernel_size=kernel_size, padding=(calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0))), Swish(), nn.Conv1d(inner_dim, dim, 1), Transpose((1, 2)), nn.Dropout(dropout))
|
611 |
-
|
612 |
-
def forward(self, x):
|
613 |
-
return self.net(x)
|
614 |
-
|
615 |
-
class ConformerConvModule(nn.Module):
|
616 |
-
def __init__(self, dim, expansion_factor=2, kernel_size=31, dropout=0):
|
617 |
-
super().__init__()
|
618 |
-
inner_dim = dim * expansion_factor
|
619 |
-
|
620 |
-
self.net = nn.Sequential(nn.LayerNorm(dim), Transpose((1, 2)), nn.Conv1d(dim, inner_dim * 2, 1), nn.GLU(dim=1), DepthWiseConv1d(inner_dim, inner_dim, kernel_size=kernel_size, padding=calc_same_padding(kernel_size)[0], groups=inner_dim), nn.SiLU(), nn.Conv1d(inner_dim, dim, 1), Transpose((1, 2)), nn.Dropout(dropout))
|
621 |
-
|
622 |
-
def forward(self, x):
|
623 |
-
return self.net(x)
|
624 |
-
|
625 |
-
class FastAttention(nn.Module):
|
626 |
-
def __init__(self, dim_heads, nb_features=None, ortho_scaling=0, causal=False, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, no_projection=False):
|
627 |
-
super().__init__()
|
628 |
-
nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))
|
629 |
-
self.dim_heads = dim_heads
|
630 |
-
self.nb_features = nb_features
|
631 |
-
self.ortho_scaling = ortho_scaling
|
632 |
-
self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows=self.nb_features, nb_columns=dim_heads, scaling=ortho_scaling, qr_uniform_q=qr_uniform_q)
|
633 |
-
projection_matrix = self.create_projection()
|
634 |
-
self.register_buffer("projection_matrix", projection_matrix)
|
635 |
-
self.generalized_attention = generalized_attention
|
636 |
-
self.kernel_fn = kernel_fn
|
637 |
-
self.no_projection = no_projection
|
638 |
-
self.causal = causal
|
639 |
-
|
640 |
-
@torch.no_grad()
|
641 |
-
def redraw_projection_matrix(self):
|
642 |
-
projections = self.create_projection()
|
643 |
-
self.projection_matrix.copy_(projections)
|
644 |
-
|
645 |
-
del projections
|
646 |
-
|
647 |
-
def forward(self, q, k, v):
|
648 |
-
if self.no_projection: q, k = q.softmax(dim=-1), (torch.exp(k) if self.causal else k.softmax(dim=-2))
|
649 |
-
else:
|
650 |
-
create_kernel = partial(softmax_kernel, projection_matrix=self.projection_matrix, device=q.device)
|
651 |
-
q, k = create_kernel(q, is_query=True), create_kernel(k, is_query=False)
|
652 |
-
|
653 |
-
attn_fn = linear_attention if not self.causal else self.causal_linear_fn
|
654 |
-
return attn_fn(q, k, None) if v is None else attn_fn(q, k, v)
|
655 |
-
|
656 |
-
class SelfAttention(nn.Module):
|
657 |
-
def __init__(self, dim, causal=False, heads=8, dim_head=64, local_heads=0, local_window_size=256, nb_features=None, feature_redraw_interval=1000, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, dropout=0.0, no_projection=False):
|
658 |
-
super().__init__()
|
659 |
-
assert dim % heads == 0
|
660 |
-
dim_head = default(dim_head, dim // heads)
|
661 |
-
inner_dim = dim_head * heads
|
662 |
-
self.fast_attention = FastAttention(dim_head, nb_features, causal=causal, generalized_attention=generalized_attention, kernel_fn=kernel_fn, qr_uniform_q=qr_uniform_q, no_projection=no_projection)
|
663 |
-
self.heads = heads
|
664 |
-
self.global_heads = heads - local_heads
|
665 |
-
self.local_attn = (LocalAttention(window_size=local_window_size, causal=causal, autopad=True, dropout=dropout, look_forward=int(not causal), rel_pos_emb_config=(dim_head, local_heads)) if local_heads > 0 else None)
|
666 |
-
self.to_q = nn.Linear(dim, inner_dim)
|
667 |
-
self.to_k = nn.Linear(dim, inner_dim)
|
668 |
-
self.to_v = nn.Linear(dim, inner_dim)
|
669 |
-
self.to_out = nn.Linear(inner_dim, dim)
|
670 |
-
self.dropout = nn.Dropout(dropout)
|
671 |
-
|
672 |
-
@torch.no_grad()
|
673 |
-
def redraw_projection_matrix(self):
|
674 |
-
self.fast_attention.redraw_projection_matrix()
|
675 |
-
|
676 |
-
def forward(self, x, context=None, mask=None, context_mask=None, name=None, inference=False, **kwargs):
|
677 |
-
_, _, _, h, gh = *x.shape, self.heads, self.global_heads
|
678 |
-
cross_attend = exists(context)
|
679 |
-
|
680 |
-
context = default(context, x)
|
681 |
-
context_mask = default(context_mask, mask) if not cross_attend else context_mask
|
682 |
-
|
683 |
-
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (self.to_q(x), self.to_k(context), self.to_v(context)))
|
684 |
-
(q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))
|
685 |
-
|
686 |
-
attn_outs = []
|
687 |
-
|
688 |
-
if not empty(q):
|
689 |
-
if exists(context_mask): v.masked_fill_(~context_mask[:, None, :, None], 0.0)
|
690 |
-
|
691 |
-
if cross_attend: pass
|
692 |
-
else: out = self.fast_attention(q, k, v)
|
693 |
-
|
694 |
-
attn_outs.append(out)
|
695 |
-
|
696 |
-
if not empty(lq):
|
697 |
-
assert (not cross_attend), "not cross_attend"
|
698 |
-
|
699 |
-
out = self.local_attn(lq, lk, lv, input_mask=mask)
|
700 |
-
attn_outs.append(out)
|
701 |
-
|
702 |
-
return self.dropout(self.to_out(rearrange(torch.cat(attn_outs, dim=1), "b h n d -> b n (h d)")))
|
703 |
-
|
704 |
-
class HannWindow(torch.nn.Module):
|
705 |
-
def __init__(self, win_size):
|
706 |
-
super().__init__()
|
707 |
-
self.register_buffer('window', torch.hann_window(win_size), persistent=False)
|
708 |
-
|
709 |
-
def forward(self):
|
710 |
-
return self.window
|
711 |
-
|
712 |
-
class FCPE_LEGACY(nn.Module):
|
713 |
-
def __init__(self, input_channel=128, out_dims=360, n_layers=12, n_chans=512, use_siren=False, use_full=False, loss_mse_scale=10, loss_l2_regularization=False, loss_l2_regularization_scale=1, loss_grad1_mse=False, loss_grad1_mse_scale=1, f0_max=1975.5, f0_min=32.70, confidence=False, threshold=0.05, use_input_conv=True):
|
714 |
-
super().__init__()
|
715 |
-
if use_siren: raise ValueError("Siren not support")
|
716 |
-
if use_full: raise ValueError("Model full not support")
|
717 |
-
|
718 |
-
self.loss_mse_scale = loss_mse_scale if (loss_mse_scale is not None) else 10
|
719 |
-
self.loss_l2_regularization = (loss_l2_regularization if (loss_l2_regularization is not None) else False)
|
720 |
-
self.loss_l2_regularization_scale = (loss_l2_regularization_scale if (loss_l2_regularization_scale is not None) else 1)
|
721 |
-
self.loss_grad1_mse = loss_grad1_mse if (loss_grad1_mse is not None) else False
|
722 |
-
self.loss_grad1_mse_scale = (loss_grad1_mse_scale if (loss_grad1_mse_scale is not None) else 1)
|
723 |
-
self.f0_max = f0_max if (f0_max is not None) else 1975.5
|
724 |
-
self.f0_min = f0_min if (f0_min is not None) else 32.70
|
725 |
-
self.confidence = confidence if (confidence is not None) else False
|
726 |
-
self.threshold = threshold if (threshold is not None) else 0.05
|
727 |
-
self.use_input_conv = use_input_conv if (use_input_conv is not None) else True
|
728 |
-
self.cent_table_b = torch.Tensor(np.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], out_dims))
|
729 |
-
self.register_buffer("cent_table", self.cent_table_b)
|
730 |
-
self.stack = nn.Sequential(nn.Conv1d(input_channel, n_chans, 3, 1, 1), nn.GroupNorm(4, n_chans), nn.LeakyReLU(), nn.Conv1d(n_chans, n_chans, 3, 1, 1))
|
731 |
-
self.decoder = PCmer(num_layers=n_layers, num_heads=8, dim_model=n_chans, dim_keys=n_chans, dim_values=n_chans, residual_dropout=0.1, attention_dropout=0.1)
|
732 |
-
self.norm = nn.LayerNorm(n_chans)
|
733 |
-
self.n_out = out_dims
|
734 |
-
self.dense_out = weight_norm(nn.Linear(n_chans, self.n_out))
|
735 |
-
|
736 |
-
def forward(self, mel, infer=True, gt_f0=None, return_hz_f0=False, cdecoder="local_argmax"):
|
737 |
-
if cdecoder == "argmax": self.cdecoder = self.cents_decoder
|
738 |
-
elif cdecoder == "local_argmax": self.cdecoder = self.cents_local_decoder
|
739 |
-
|
740 |
-
x = torch.sigmoid(self.dense_out(self.norm(self.decoder((self.stack(mel.transpose(1, 2)).transpose(1, 2) if self.use_input_conv else mel)))))
|
741 |
-
|
742 |
-
if not infer:
|
743 |
-
loss_all = self.loss_mse_scale * F.binary_cross_entropy(x, self.gaussian_blurred_cent(self.f0_to_cent(gt_f0)))
|
744 |
-
if self.loss_l2_regularization: loss_all = loss_all + l2_regularization(model=self, l2_alpha=self.loss_l2_regularization_scale)
|
745 |
-
x = loss_all
|
746 |
-
|
747 |
-
if infer:
|
748 |
-
x = self.cent_to_f0(self.cdecoder(x))
|
749 |
-
x = (1 + x / 700).log() if not return_hz_f0 else x
|
750 |
-
|
751 |
-
return x
|
752 |
-
|
753 |
-
def cents_decoder(self, y, mask=True):
|
754 |
-
B, N, _ = y.size()
|
755 |
-
rtn = torch.sum(self.cent_table[None, None, :].expand(B, N, -1) * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True)
|
756 |
-
|
757 |
-
if mask:
|
758 |
-
confident = torch.max(y, dim=-1, keepdim=True)[0]
|
759 |
-
confident_mask = torch.ones_like(confident)
|
760 |
-
|
761 |
-
confident_mask[confident <= self.threshold] = float("-INF")
|
762 |
-
rtn = rtn * confident_mask
|
763 |
-
|
764 |
-
return (rtn, confident) if self.confidence else rtn
|
765 |
-
|
766 |
-
def cents_local_decoder(self, y, mask=True):
|
767 |
-
B, N, _ = y.size()
|
768 |
-
|
769 |
-
confident, max_index = torch.max(y, dim=-1, keepdim=True)
|
770 |
-
local_argmax_index = torch.clamp(torch.arange(0, 9).to(max_index.device) + (max_index - 4), 0, self.n_out - 1)
|
771 |
-
|
772 |
-
y_l = torch.gather(y, -1, local_argmax_index)
|
773 |
-
rtn = torch.sum(torch.gather(self.cent_table[None, None, :].expand(B, N, -1), -1, local_argmax_index) * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True)
|
774 |
-
|
775 |
-
if mask:
|
776 |
-
confident_mask = torch.ones_like(confident)
|
777 |
-
confident_mask[confident <= self.threshold] = float("-INF")
|
778 |
-
rtn = rtn * confident_mask
|
779 |
-
|
780 |
-
return (rtn, confident) if self.confidence else rtn
|
781 |
-
|
782 |
-
def cent_to_f0(self, cent):
|
783 |
-
return 10.0 * 2 ** (cent / 1200.0)
|
784 |
-
|
785 |
-
def f0_to_cent(self, f0):
|
786 |
-
return 1200.0 * torch.log2(f0 / 10.0)
|
787 |
-
|
788 |
-
def gaussian_blurred_cent(self, cents):
|
789 |
-
B, N, _ = cents.size()
|
790 |
-
return torch.exp(-torch.square(self.cent_table[None, None, :].expand(B, N, -1) - cents) / 1250) * (cents > 0.1) & (cents < (1200.0 * np.log2(self.f0_max / 10.0))).float()
|
791 |
-
|
792 |
-
class InferCFNaiveMelPE(torch.nn.Module):
|
793 |
-
def __init__(self, args, state_dict):
|
794 |
-
super().__init__()
|
795 |
-
self.wav2mel = spawn_wav2mel(args, device="cpu")
|
796 |
-
self.model = spawn_model(args)
|
797 |
-
self.model.load_state_dict(state_dict)
|
798 |
-
self.model.eval()
|
799 |
-
self.args_dict = dict(args)
|
800 |
-
self.register_buffer("tensor_device_marker", torch.tensor(1.0).float(), persistent=False)
|
801 |
-
|
802 |
-
def forward(self, wav, sr, decoder_mode = "local_argmax", threshold = 0.006, key_shifts = [0]):
|
803 |
-
with torch.no_grad():
|
804 |
-
mels = rearrange(torch.stack([self.wav2mel(wav.to(self.tensor_device_marker.device), sr, keyshift=keyshift) for keyshift in key_shifts], -1), "B T C K -> (B K) T C")
|
805 |
-
f0s = rearrange(self.model.infer(mels, decoder=decoder_mode, threshold=threshold), "(B K) T 1 -> B T (K 1)", K=len(key_shifts))
|
806 |
-
|
807 |
-
return f0s
|
808 |
-
|
809 |
-
def infer(self, wav, sr, decoder_mode = "local_argmax", threshold = 0.006, f0_min = None, f0_max = None, interp_uv = False, output_interp_target_length = None, return_uv = False, test_time_augmentation = False, tta_uv_penalty = 12.0, tta_key_shifts = [0, -12, 12], tta_use_origin_uv=False):
|
810 |
-
if test_time_augmentation:
|
811 |
-
assert len(tta_key_shifts) > 0
|
812 |
-
flag = 0
|
813 |
-
|
814 |
-
if tta_use_origin_uv:
|
815 |
-
if 0 not in tta_key_shifts:
|
816 |
-
flag = 1
|
817 |
-
tta_key_shifts.append(0)
|
818 |
-
|
819 |
-
tta_key_shifts.sort(key=lambda x: (x if x >= 0 else -x / 2))
|
820 |
-
f0s = self.__call__(wav, sr, decoder_mode, threshold, tta_key_shifts)
|
821 |
-
f0 = ensemble_f0(f0s[:, :, flag:], tta_key_shifts[flag:], tta_uv_penalty)
|
822 |
-
|
823 |
-
f0_for_uv = f0s[:, :, [0]] if tta_use_origin_uv else f0
|
824 |
-
else:
|
825 |
-
f0 = self.__call__(wav, sr, decoder_mode, threshold)
|
826 |
-
f0_for_uv = f0
|
827 |
-
|
828 |
-
if f0_min is None: f0_min = self.args_dict["model"]["f0_min"]
|
829 |
-
|
830 |
-
uv = (f0_for_uv < f0_min).type(f0_for_uv.dtype)
|
831 |
-
f0 = f0 * (1 - uv)
|
832 |
-
|
833 |
-
if interp_uv: f0 = batch_interp_with_replacement_detach(uv.squeeze(-1).bool(), f0.squeeze(-1)).unsqueeze(-1)
|
834 |
-
if f0_max is not None: f0[f0 > f0_max] = f0_max
|
835 |
-
if output_interp_target_length is not None: f0 = F.interpolate(f0.transpose(1, 2), size=int(output_interp_target_length), mode="nearest").transpose(1, 2)
|
836 |
-
|
837 |
-
if return_uv: return f0, F.interpolate(uv.transpose(1, 2), size=int(output_interp_target_length), mode="nearest").transpose(1, 2)
|
838 |
-
else: return f0
|
839 |
-
|
840 |
-
class FCPEInfer_LEGACY:
|
841 |
-
def __init__(self, model_path, device=None, dtype=torch.float32, providers=None, onnx=False):
|
842 |
-
if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
|
843 |
-
self.device = device
|
844 |
-
self.dtype = dtype
|
845 |
-
self.onnx = onnx
|
846 |
-
|
847 |
-
if self.onnx:
|
848 |
-
sess_options = ort.SessionOptions()
|
849 |
-
sess_options.log_severity_level = 3
|
850 |
-
|
851 |
-
self.model = ort.InferenceSession(decrypt_model(model_path), sess_options=sess_options, providers=providers)
|
852 |
-
else:
|
853 |
-
ckpt = torch.load(model_path, map_location=torch.device(self.device))
|
854 |
-
self.args = DotDict(ckpt["config"])
|
855 |
-
|
856 |
-
model = FCPE_LEGACY(input_channel=self.args.model.input_channel, out_dims=self.args.model.out_dims, n_layers=self.args.model.n_layers, n_chans=self.args.model.n_chans, use_siren=self.args.model.use_siren, use_full=self.args.model.use_full, loss_mse_scale=self.args.loss.loss_mse_scale, loss_l2_regularization=self.args.loss.loss_l2_regularization, loss_l2_regularization_scale=self.args.loss.loss_l2_regularization_scale, loss_grad1_mse=self.args.loss.loss_grad1_mse, loss_grad1_mse_scale=self.args.loss.loss_grad1_mse_scale, f0_max=self.args.model.f0_max, f0_min=self.args.model.f0_min, confidence=self.args.model.confidence)
|
857 |
-
model.to(self.device).to(self.dtype)
|
858 |
-
model.load_state_dict(ckpt["model"])
|
859 |
-
|
860 |
-
model.eval()
|
861 |
-
self.model = model
|
862 |
-
|
863 |
-
@torch.no_grad()
|
864 |
-
def __call__(self, audio, sr, threshold=0.05):
|
865 |
-
if not self.onnx: self.model.threshold = threshold
|
866 |
-
self.wav2mel = Wav2Mel(device=self.device, dtype=self.dtype)
|
867 |
-
|
868 |
-
return (torch.as_tensor(self.model.run([self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype).detach().cpu().numpy(), self.model.get_inputs()[1].name: np.array(threshold, dtype=np.float32)})[0], dtype=self.dtype, device=self.device) if self.onnx else self.model(mel=self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype), infer=True, return_hz_f0=True))
|
869 |
-
|
870 |
-
class FCPEInfer:
|
871 |
-
def __init__(self, model_path, device=None, dtype=torch.float32, providers=None, onnx=False):
|
872 |
-
if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
|
873 |
-
self.device = device
|
874 |
-
self.dtype = dtype
|
875 |
-
self.onnx = onnx
|
876 |
-
|
877 |
-
if self.onnx:
|
878 |
-
sess_options = ort.SessionOptions()
|
879 |
-
sess_options.log_severity_level = 3
|
880 |
-
|
881 |
-
self.model = ort.InferenceSession(decrypt_model(model_path), sess_options=sess_options, providers=providers)
|
882 |
-
else:
|
883 |
-
ckpt = torch.load(model_path, map_location=torch.device(device))
|
884 |
-
ckpt["config_dict"]["model"]["conv_dropout"] = ckpt["config_dict"]["model"]["atten_dropout"] = 0.0
|
885 |
-
self.args = DotDict(ckpt["config_dict"])
|
886 |
-
|
887 |
-
model = InferCFNaiveMelPE(self.args, ckpt["model"])
|
888 |
-
model = model.to(device)
|
889 |
-
|
890 |
-
model.eval()
|
891 |
-
self.model = model
|
892 |
-
|
893 |
-
@torch.no_grad()
|
894 |
-
def __call__(self, audio, sr, threshold=0.05, f0_min=50, f0_max=1100, p_len=None):
|
895 |
-
if self.onnx: self.wav2mel = Wav2Mel(device=self.device, dtype=self.dtype)
|
896 |
-
return (torch.as_tensor(self.model.run([self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype).detach().cpu().numpy(), self.model.get_inputs()[1].name: np.array(threshold, dtype=np.float32)})[0], dtype=self.dtype, device=self.device) if self.onnx else self.model.infer(audio[None, :], sr, threshold=threshold, f0_min=f0_min, f0_max=f0_max, output_interp_target_length=p_len))
|
897 |
-
|
898 |
-
class MelModule(torch.nn.Module):
|
899 |
-
def __init__(self, sr, n_mels, n_fft, win_size, hop_length, fmin = None, fmax = None, clip_val = 1e-5, out_stft = False):
|
900 |
-
super().__init__()
|
901 |
-
if fmin is None: fmin = 0
|
902 |
-
if fmax is None: fmax = sr / 2
|
903 |
-
|
904 |
-
self.target_sr = sr
|
905 |
-
self.n_mels = n_mels
|
906 |
-
self.n_fft = n_fft
|
907 |
-
self.win_size = win_size
|
908 |
-
self.hop_length = hop_length
|
909 |
-
self.fmin = fmin
|
910 |
-
self.fmax = fmax
|
911 |
-
self.clip_val = clip_val
|
912 |
-
|
913 |
-
self.register_buffer('mel_basis', torch.tensor(librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)).float(), persistent=False)
|
914 |
-
self.hann_window = torch.nn.ModuleDict()
|
915 |
-
self.out_stft = out_stft
|
916 |
-
|
917 |
-
@torch.no_grad()
|
918 |
-
def __call__(self, y, key_shift = 0, speed = 1, center = False, no_cache_window = False):
|
919 |
-
n_fft = self.n_fft
|
920 |
-
win_size = self.win_size
|
921 |
-
hop_length = self.hop_length
|
922 |
-
clip_val = self.clip_val
|
923 |
-
|
924 |
-
factor = 2 ** (key_shift / 12)
|
925 |
-
n_fft_new = int(np.round(n_fft * factor))
|
926 |
-
win_size_new = int(np.round(win_size * factor))
|
927 |
-
hop_length_new = int(np.round(hop_length * speed))
|
928 |
-
|
929 |
-
y = y.squeeze(-1)
|
930 |
-
|
931 |
-
if torch.min(y) < -1: print('[error with torchfcpe.mel_extractor.MelModule] min ', torch.min(y))
|
932 |
-
if torch.max(y) > 1: print('[error with torchfcpe.mel_extractor.MelModule] max ', torch.max(y))
|
933 |
-
|
934 |
-
key_shift_key = str(key_shift)
|
935 |
-
if not no_cache_window:
|
936 |
-
if key_shift_key in self.hann_window: hann_window = self.hann_window[key_shift_key]
|
937 |
-
else:
|
938 |
-
hann_window = HannWindow(win_size_new).to(self.mel_basis.device)
|
939 |
-
self.hann_window[key_shift_key] = hann_window
|
940 |
-
|
941 |
-
hann_window_tensor = hann_window()
|
942 |
-
else: hann_window_tensor = torch.hann_window(win_size_new).to(self.mel_basis.device)
|
943 |
-
|
944 |
-
pad_left = (win_size_new - hop_length_new) // 2
|
945 |
-
pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)
|
946 |
-
|
947 |
-
mode = 'reflect' if pad_right < y.size(-1) else 'constant'
|
948 |
-
|
949 |
-
spec = torch.stft(F.pad(y.unsqueeze(1), (pad_left, pad_right), mode=mode).squeeze(1), n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=hann_window_tensor, center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
|
950 |
-
spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-9)
|
951 |
-
|
952 |
-
if key_shift != 0:
|
953 |
-
size = n_fft // 2 + 1
|
954 |
-
resize = spec.size(1)
|
955 |
-
|
956 |
-
if resize < size: spec = F.pad(spec, (0, 0, 0, size - resize))
|
957 |
-
spec = spec[:, :size, :] * win_size / win_size_new
|
958 |
-
|
959 |
-
spec = spec[:, :512, :] if self.out_stft else torch.matmul(self.mel_basis, spec)
|
960 |
-
|
961 |
-
return dynamic_range_compression_torch(spec, clip_val=clip_val).transpose(-1, -2)
|
962 |
-
|
963 |
-
class Wav2MelModule(torch.nn.Module):
|
964 |
-
def __init__(self, sr, n_mels, n_fft, win_size, hop_length, fmin = None, fmax = None, clip_val = 1e-5, mel_type="default"):
|
965 |
-
super().__init__()
|
966 |
-
if fmin is None: fmin = 0
|
967 |
-
if fmax is None: fmax = sr / 2
|
968 |
-
|
969 |
-
self.sampling_rate = sr
|
970 |
-
self.n_mels = n_mels
|
971 |
-
self.n_fft = n_fft
|
972 |
-
self.win_size = win_size
|
973 |
-
self.hop_size = hop_length
|
974 |
-
self.fmin = fmin
|
975 |
-
self.fmax = fmax
|
976 |
-
self.clip_val = clip_val
|
977 |
-
|
978 |
-
self.register_buffer('tensor_device_marker', torch.tensor(1.0).float(), persistent=False)
|
979 |
-
self.resample_kernel = torch.nn.ModuleDict()
|
980 |
-
|
981 |
-
if mel_type == "default": self.mel_extractor = MelModule(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val, out_stft=False)
|
982 |
-
elif mel_type == "stft": self.mel_extractor = MelModule(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val, out_stft=True)
|
983 |
-
|
984 |
-
self.mel_type = mel_type
|
985 |
-
|
986 |
-
@torch.no_grad()
|
987 |
-
def __call__(self, audio, sample_rate, keyshift = 0, no_cache_window = False):
|
988 |
-
|
989 |
-
if sample_rate == self.sampling_rate: audio_res = audio
|
990 |
-
else:
|
991 |
-
key_str = str(sample_rate)
|
992 |
-
|
993 |
-
if key_str not in self.resample_kernel:
|
994 |
-
if len(self.resample_kernel) > 8: self.resample_kernel.clear()
|
995 |
-
self.resample_kernel[key_str] = Resample(sample_rate, self.sampling_rate, lowpass_filter_width=128).to(self.tensor_device_marker.device)
|
996 |
-
|
997 |
-
audio_res = self.resample_kernel[key_str](audio.squeeze(-1)).unsqueeze(-1)
|
998 |
-
|
999 |
-
mel = self.mel_extractor(audio_res, keyshift, no_cache_window=no_cache_window)
|
1000 |
-
n_frames = int(audio.shape[1] // self.hop_size) + 1
|
1001 |
-
|
1002 |
-
if n_frames > int(mel.shape[1]): mel = torch.cat((mel, mel[:, -1:, :]), 1)
|
1003 |
-
if n_frames < int(mel.shape[1]): mel = mel[:, :n_frames, :]
|
1004 |
-
|
1005 |
-
return mel
|
1006 |
-
|
1007 |
-
class Wav2Mel:
|
1008 |
-
def __init__(self, device=None, dtype=torch.float32):
|
1009 |
-
self.sample_rate = 16000
|
1010 |
-
self.hop_size = 160
|
1011 |
-
if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
|
1012 |
-
self.device = device
|
1013 |
-
self.dtype = dtype
|
-        self.stft = STFT(16000, 128, 1024, 1024, 160, 0, 8000)
-        self.resample_kernel = {}
-
-    def extract_nvstft(self, audio, keyshift=0, train=False):
-        return self.stft.get_mel(audio, keyshift=keyshift, train=train).transpose(1, 2)
-
-    def extract_mel(self, audio, sample_rate, keyshift=0, train=False):
-        audio = audio.to(self.dtype).to(self.device)
-
-        if sample_rate == self.sample_rate: audio_res = audio
-        else:
-            key_str = str(sample_rate)
-            if key_str not in self.resample_kernel: self.resample_kernel[key_str] = Resample(sample_rate, self.sample_rate, lowpass_filter_width=128)
-
-            self.resample_kernel[key_str] = (self.resample_kernel[key_str].to(self.dtype).to(self.device))
-            audio_res = self.resample_kernel[key_str](audio)
-
-        mel = self.extract_nvstft(audio_res, keyshift=keyshift, train=train)
-        n_frames = int(audio.shape[1] // self.hop_size) + 1
-
-        mel = (torch.cat((mel, mel[:, -1:, :]), 1) if n_frames > int(mel.shape[1]) else mel)
-        return mel[:, :n_frames, :] if n_frames < int(mel.shape[1]) else mel
-
-    def __call__(self, audio, sample_rate, keyshift=0, train=False):
-        return self.extract_mel(audio, sample_rate, keyshift=keyshift, train=train)
-
-class DotDict(dict):
-    def __getattr__(*args):
-        val = dict.get(*args)
-        return DotDict(val) if type(val) is dict else val
-
-    __setattr__ = dict.__setitem__
-    __delattr__ = dict.__delitem__
-
-class FCPE:
-    def __init__(self, model_path, hop_length=512, f0_min=50, f0_max=1100, dtype=torch.float32, device=None, sample_rate=44100, threshold=0.05, providers=None, onnx=False, legacy=False):
-        self.fcpe = FCPEInfer_LEGACY(model_path, device=device, dtype=dtype, providers=providers, onnx=onnx) if legacy else FCPEInfer(model_path, device=device, dtype=dtype, providers=providers, onnx=onnx)
-        self.hop_length = hop_length
-        self.f0_min = f0_min
-        self.f0_max = f0_max
-        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        self.threshold = threshold
-        self.sample_rate = sample_rate
-        self.dtype = dtype
-        self.legacy = legacy
-        self.name = "fcpe"
-
-    def repeat_expand(self, content, target_len, mode = "nearest"):
-        ndim = content.ndim
-        content = (content[None, None] if ndim == 1 else content[None] if ndim == 2 else content)
-
-        assert content.ndim == 3
-        is_np = isinstance(content, np.ndarray)
-
-        results = F.interpolate(torch.from_numpy(content) if is_np else content, size=target_len, mode=mode)
-        results = results.numpy() if is_np else results
-        return results[0, 0] if ndim == 1 else results[0] if ndim == 2 else results
-
-    def post_process(self, x, sample_rate, f0, pad_to):
-        f0 = (torch.from_numpy(f0).float().to(x.device) if isinstance(f0, np.ndarray) else f0)
-        f0 = self.repeat_expand(f0, pad_to) if pad_to is not None else f0
-
-        vuv_vector = torch.zeros_like(f0)
-        vuv_vector[f0 > 0.0] = 1.0
-        vuv_vector[f0 <= 0.0] = 0.0
-
-        nzindex = torch.nonzero(f0).squeeze()
-        f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy()
-        vuv_vector = F.interpolate(vuv_vector[None, None, :], size=pad_to)[0][0]
-
-        if f0.shape[0] <= 0: return np.zeros(pad_to), vuv_vector.cpu().numpy()
-        if f0.shape[0] == 1: return np.ones(pad_to) * f0[0], vuv_vector.cpu().numpy()
-
-        return np.interp(np.arange(pad_to) * self.hop_length / sample_rate, self.hop_length / sample_rate * nzindex.cpu().numpy(), f0, left=f0[0], right=f0[-1]), vuv_vector.cpu().numpy()
-
-    def compute_f0(self, wav, p_len=None):
-        x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
-        p_len = x.shape[0] // self.hop_length if p_len is None else p_len
-
-        f0 = self.fcpe(x, sr=self.sample_rate, threshold=self.threshold) if self.legacy else (self.fcpe(x, sr=self.sample_rate, threshold=self.threshold, f0_min=self.f0_min, f0_max=self.f0_max, p_len=p_len))
-        f0 = f0[:] if f0.dim() == 1 else f0[0, :, 0]
-
-        if torch.all(f0 == 0): return f0.cpu().numpy() if p_len is None else np.zeros(p_len), (f0.cpu().numpy() if p_len is None else np.zeros(p_len))
-        return self.post_process(x, self.sample_rate, f0, p_len)[0]
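The deleted FCPE wrapper above is driven through its compute_f0() method. A minimal usage sketch follows; the checkpoint path, hop length and the silent test signal are illustrative assumptions, not values taken from this diff:

import numpy as np
from main.library.predictors.FCPE import FCPE

# Hypothetical checkpoint location; point this at wherever the fcpe weights live.
f0_predictor = FCPE("assets/models/predictors/fcpe.pt", hop_length=160, sample_rate=16000, device="cpu", threshold=0.03)
wav = np.zeros(16000, dtype=np.float32)    # one second of 16 kHz audio
f0 = f0_predictor.compute_f0(wav)          # frame-level F0 contour in Hz, 0 where unvoiced
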
main/library/predictors/RMVPE.py
DELETED
@@ -1,260 +0,0 @@
-import torch
-
-import numpy as np
-import torch.nn as nn
-import torch.nn.functional as F
-
-from librosa.filters import mel
-
-N_MELS, N_CLASS = 128, 360
-
-class ConvBlockRes(nn.Module):
-    def __init__(self, in_channels, out_channels, momentum=0.01):
-        super(ConvBlockRes, self).__init__()
-        self.conv = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU())
-
-        if in_channels != out_channels:
-            self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
-            self.is_shortcut = True
-        else: self.is_shortcut = False
-
-    def forward(self, x):
-        return self.conv(x) + self.shortcut(x) if self.is_shortcut else self.conv(x) + x
-
-class ResEncoderBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
-        super(ResEncoderBlock, self).__init__()
-        self.n_blocks = n_blocks
-        self.conv = nn.ModuleList()
-        self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
-
-        for _ in range(n_blocks - 1):
-            self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
-
-        self.kernel_size = kernel_size
-        if self.kernel_size is not None: self.pool = nn.AvgPool2d(kernel_size=kernel_size)
-
-    def forward(self, x):
-        for i in range(self.n_blocks):
-            x = self.conv[i](x)
-
-        if self.kernel_size is not None: return x, self.pool(x)
-        else: return x
-
-class Encoder(nn.Module):
-    def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
-        super(Encoder, self).__init__()
-        self.n_encoders = n_encoders
-        self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
-        self.layers = nn.ModuleList()
-        self.latent_channels = []
-
-        for _ in range(self.n_encoders):
-            self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum))
-            self.latent_channels.append([out_channels, in_size])
-            in_channels = out_channels
-            out_channels *= 2
-            in_size //= 2
-
-        self.out_size = in_size
-        self.out_channel = out_channels
-
-    def forward(self, x):
-        concat_tensors = []
-        x = self.bn(x)
-
-        for i in range(self.n_encoders):
-            t, x = self.layers[i](x)
-            concat_tensors.append(t)
-
-        return x, concat_tensors
-
-class Intermediate(nn.Module):
-    def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
-        super(Intermediate, self).__init__()
-        self.n_inters = n_inters
-        self.layers = nn.ModuleList()
-        self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
-
-        for _ in range(self.n_inters - 1):
-            self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))
-
-    def forward(self, x):
-        for i in range(self.n_inters):
-            x = self.layers[i](x)
-
-        return x
-
-class ResDecoderBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
-        super(ResDecoderBlock, self).__init__()
-        out_padding = (0, 1) if stride == (1, 2) else (1, 1)
-        self.n_blocks = n_blocks
-        self.conv1 = nn.Sequential(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=stride, padding=(1, 1), output_padding=out_padding, bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU())
-        self.conv2 = nn.ModuleList()
-        self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
-
-        for _ in range(n_blocks - 1):
-            self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
-
-    def forward(self, x, concat_tensor):
-        x = torch.cat((self.conv1(x), concat_tensor), dim=1)
-
-        for i in range(self.n_blocks):
-            x = self.conv2[i](x)
-
-        return x
-
-class Decoder(nn.Module):
-    def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
-        super(Decoder, self).__init__()
-        self.layers = nn.ModuleList()
-        self.n_decoders = n_decoders
-
-        for _ in range(self.n_decoders):
-            out_channels = in_channels // 2
-            self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum))
-            in_channels = out_channels
-
-    def forward(self, x, concat_tensors):
-        for i in range(self.n_decoders):
-            x = self.layers[i](x, concat_tensors[-1 - i])
-
-        return x
-
-class DeepUnet(nn.Module):
-    def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
-        super(DeepUnet, self).__init__()
-        self.encoder = Encoder(in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels)
-        self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
-        self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
-
-    def forward(self, x):
-        x, concat_tensors = self.encoder(x)
-        return self.decoder(self.intermediate(x), concat_tensors)
-
-class E2E(nn.Module):
-    def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
-        super(E2E, self).__init__()
-        self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
-        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
-        self.fc = nn.Sequential(BiGRU(3 * 128, 256, n_gru), nn.Linear(512, N_CLASS), nn.Dropout(0.25), nn.Sigmoid()) if n_gru else nn.Sequential(nn.Linear(3 * N_MELS, N_CLASS), nn.Dropout(0.25), nn.Sigmoid())
-
-    def forward(self, mel):
-        return self.fc(self.cnn(self.unet(mel.transpose(-1, -2).unsqueeze(1))).transpose(1, 2).flatten(-2))
-
-class MelSpectrogram(torch.nn.Module):
-    def __init__(self, is_half, n_mel_channels, sample_rate, win_length, hop_length, n_fft=None, mel_fmin=0, mel_fmax=None, clamp=1e-5):
-        super().__init__()
-        n_fft = win_length if n_fft is None else n_fft
-        self.hann_window = {}
-        mel_basis = mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax, htk=True)
-        mel_basis = torch.from_numpy(mel_basis).float()
-        self.register_buffer("mel_basis", mel_basis)
-        self.n_fft = win_length if n_fft is None else n_fft
-        self.hop_length = hop_length
-        self.win_length = win_length
-        self.sample_rate = sample_rate
-        self.n_mel_channels = n_mel_channels
-        self.clamp = clamp
-        self.is_half = is_half
-
-    def forward(self, audio, keyshift=0, speed=1, center=True):
-        factor = 2 ** (keyshift / 12)
-        win_length_new = int(np.round(self.win_length * factor))
-        keyshift_key = str(keyshift) + "_" + str(audio.device)
-        if keyshift_key not in self.hann_window: self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
-
-        fft = torch.stft(audio, n_fft=int(np.round(self.n_fft * factor)), hop_length=int(np.round(self.hop_length * speed)), win_length=win_length_new, window=self.hann_window[keyshift_key], center=center, return_complex=True)
-        magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
-
-        if keyshift != 0:
-            size = self.n_fft // 2 + 1
-            resize = magnitude.size(1)
-
-            if resize < size: magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
-            magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
-
-        mel_output = torch.matmul(self.mel_basis, magnitude)
-        if self.is_half: mel_output = mel_output.half()
-
-        return torch.log(torch.clamp(mel_output, min=self.clamp))
-
-class RMVPE:
-    def __init__(self, model_path, is_half, device=None, providers=None, onnx=False):
-        self.resample_kernel = {}
-        self.onnx = onnx
-
-        if self.onnx:
-            import onnxruntime as ort
-
-            sess_options = ort.SessionOptions()
-            sess_options.log_severity_level = 3
-
-            self.model = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
-        else:
-            model = E2E(4, 1, (2, 2))
-            ckpt = torch.load(model_path, map_location="cpu")
-            model.load_state_dict(ckpt)
-            model.eval()
-            if is_half: model = model.half()
-            self.model = model.to(device)
-
-        self.resample_kernel = {}
-        self.is_half = is_half
-        self.device = device
-        self.mel_extractor = MelSpectrogram(is_half, N_MELS, 16000, 1024, 160, None, 30, 8000).to(device)
-        cents_mapping = 20 * np.arange(N_CLASS) + 1997.3794084376191
-        self.cents_mapping = np.pad(cents_mapping, (4, 4))
-
-    def mel2hidden(self, mel):
-        with torch.no_grad():
-            n_frames = mel.shape[-1]
-            mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect")
-            hidden = self.model.run([self.model.get_outputs()[0].name], input_feed={self.model.get_inputs()[0].name: mel.cpu().numpy().astype(np.float32)})[0] if self.onnx else self.model(mel.half() if self.is_half else mel.float())
-            return hidden[:, :n_frames]
-
-    def decode(self, hidden, thred=0.03):
-        f0 = 10 * (2 ** (self.to_local_average_cents(hidden, thred=thred) / 1200))
-        f0[f0 == 10] = 0
-
-        return f0
-
-    def infer_from_audio(self, audio, thred=0.03):
-        hidden = self.mel2hidden(self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True))
-
-        return self.decode((hidden.squeeze(0).cpu().numpy().astype(np.float32) if self.is_half else hidden.squeeze(0).cpu().numpy()) if not self.onnx else hidden[0], thred=thred)
-
-    def infer_from_audio_with_pitch(self, audio, thred=0.03, f0_min=50, f0_max=1100):
-        hidden = self.mel2hidden(self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True))
-
-        f0 = self.decode((hidden.squeeze(0).cpu().numpy().astype(np.float32) if self.is_half else hidden.squeeze(0).cpu().numpy()) if not self.onnx else hidden[0], thred=thred)
-        f0[(f0 < f0_min) | (f0 > f0_max)] = 0
-
-        return f0
-
-    def to_local_average_cents(self, salience, thred=0.05):
-        center = np.argmax(salience, axis=1)
-        salience = np.pad(salience, ((0, 0), (4, 4)))
-        center += 4
-        todo_salience, todo_cents_mapping = [], []
-        starts = center - 4
-        ends = center + 5
-
-        for idx in range(salience.shape[0]):
-            todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
-            todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
-
-        todo_salience = np.array(todo_salience)
-        devided = np.sum(todo_salience * np.array(todo_cents_mapping), 1) / np.sum(todo_salience, 1)
-        devided[np.max(salience, axis=1) <= thred] = 0
-
-        return devided
-
-class BiGRU(nn.Module):
-    def __init__(self, input_features, hidden_features, num_layers):
-        super(BiGRU, self).__init__()
-        self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
-
-    def forward(self, x):
-        return self.gru(x)[0]
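The RMVPE predictor deleted above exposes infer_from_audio() and infer_from_audio_with_pitch() on 16 kHz mono audio. A minimal usage sketch, with an assumed checkpoint path and a dummy signal (both illustrative, not taken from this diff):

import numpy as np
from main.library.predictors.RMVPE import RMVPE

rmvpe = RMVPE("assets/models/predictors/rmvpe.pt", is_half=False, device="cpu")  # hypothetical path
audio = np.zeros(16000, dtype=np.float32)    # one second of 16 kHz mono audio
f0 = rmvpe.infer_from_audio_with_pitch(audio, thred=0.03, f0_min=50, f0_max=1100)
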
main/library/predictors/SWIPE.py
DELETED
@@ -1,140 +0,0 @@
-import math
-
-import numpy as np
-
-from matplotlib import mlab
-from scipy import interpolate
-from decimal import Decimal, ROUND_HALF_UP
-
-def swipe(x, fs, f0_floor=50, f0_ceil=1100, frame_period=10, sTHR=0.3):
-    plim = np.array([f0_floor, f0_ceil])
-    t = np.arange(0, int(1000 * len(x) / fs / (frame_period) + 1)) * (frame_period / 1000)
-
-    log2pc = np.arange(np.log2(plim[0]) * 96, np.log2(plim[-1]) * 96)
-    log2pc *= (1 / 96)
-
-    pc = 2 ** log2pc
-    S = np.zeros((len(pc), len(t)))
-
-    logWs = [round_matlab(elm) for elm in np.log2(4 * 2 * fs / plim)]
-    ws = 2 ** np.arange(logWs[0], logWs[1] - 1, -1)
-    p0 = 4 * 2 * fs / ws
-
-    d = 1 + log2pc - np.log2(4 * 2 * fs / ws[0])
-    fERBs = erbs2hz(np.arange(hz2erbs(pc[0] / 4), hz2erbs(fs / 2), 0.1))
-
-    for i in range(len(ws)):
-        dn = round_matlab(4 * fs / p0[i])
-        X, f, ti = mlab.specgram(x=np.r_[np.zeros(int(ws[i] / 2)), np.r_[x, np.zeros(int(dn + ws[i] / 2))]], NFFT=ws[i], Fs=fs, window=np.hanning(ws[i] + 2)[1:-1], noverlap=max(0, np.round(ws[i] - dn)), mode='complex')
-        ti = np.r_[0, ti[:-1]]
-        M = np.maximum(0, interpolate.interp1d(f, np.abs(X.T), kind='cubic')(fERBs)).T
-
-        if i == len(ws) - 1:
-            j = np.where(d - (i + 1) > -1)[0]
-            k = np.where(d[j] - (i + 1) < 0)[0]
-        elif i == 0:
-            j = np.where(d - (i + 1) < 1)[0]
-            k = np.where(d[j] - (i + 1) > 0)[0]
-        else:
-            j = np.where(np.abs(d - (i + 1)) < 1)[0]
-            k = np.arange(len(j))
-
-        Si = pitchStrengthAllCandidates(fERBs, np.sqrt(M), pc[j])
-        Si = interpolate.interp1d(ti, Si, bounds_error=False, fill_value='nan')(t) if Si.shape[1] > 1 else np.full((len(Si), len(t)), np.nan)
-
-        mu = np.ones(j.shape)
-        mu[k] = 1 - np.abs(d[j[k]] - i - 1)
-        S[j, :] = S[j, :] + np.tile(mu.reshape(-1, 1), (1, Si.shape[1])) * Si
-
-
-    p = np.full((S.shape[1], 1), np.nan)
-    s = np.full((S.shape[1], 1), np.nan)
-
-    for j in range(S.shape[1]):
-        s[j] = np.max(S[:, j])
-        i = np.argmax(S[:, j])
-
-        if s[j] < sTHR: continue
-
-        if i == 0: p[j] = pc[0]
-        elif i == len(pc) - 1: p[j] = pc[0]
-        else:
-            I = np.arange(i-1, i+2)
-            tc = 1 / pc[I]
-
-            ntc = (tc / tc[1] - 1) * 2 * np.pi
-            idx = np.isfinite(S[I, j])
-
-            c = np.zeros(len(ntc))
-            c += np.nan
-
-            I_ = I[idx]
-
-            if len(I_) < 2: c[idx] = (S[I, j])[0] / ntc[0]
-            else: c[idx] = np.polyfit(ntc[idx], (S[I_, j]), 2)
-
-            pval = np.polyval(c, ((1 / (2 ** np.arange(np.log2(pc[I[0]]), np.log2(pc[I[2]]) + 1 / 12 / 64, 1 / 12 / 64))) / tc[1] - 1) * 2 * np.pi)
-            s[j] = np.max(pval)
-            p[j] = 2 ** (np.log2(pc[I[0]]) + (np.argmax(pval)) / 12 / 64)
-
-    p = p.flatten()
-    p[np.isnan(p)] = 0
-
-    return np.array(p, dtype=np.float32), np.array(t, dtype=np.float32)
-
-def round_matlab(n):
-    return int(Decimal(n).quantize(0, ROUND_HALF_UP))
-
-def pitchStrengthAllCandidates(f, L, pc):
-    den = np.sqrt(np.sum(L * L, axis=0))
-    den = np.where(den == 0, 2.220446049250313e-16, den)
-
-    L = L / den
-    S = np.zeros((len(pc), L.shape[1]))
-
-    for j in range(len(pc)):
-        S[j,:] = pitchStrengthOneCandidate(f, L, pc[j])
-
-    return S
-
-def pitchStrengthOneCandidate(f, L, pc):
-    k = np.zeros(len(f))
-    q = f / pc
-
-    for i in ([1] + sieve(int(np.fix(f[-1] / pc - 0.75)))):
-        a = np.abs(q - i)
-        p = a < 0.25
-        k[p] = np.cos(2 * np.pi * q[p])
-
-        v = np.logical_and((0.25 < a), (a < 0.75))
-        k[v] = k[v] + np.cos(2 * np.pi * q[v]) / 2
-
-    k *= np.sqrt(1 / f)
-    k /= np.linalg.norm(k[k>0])
-
-    return k @ L
-
-def hz2erbs(hz):
-    return 21.4 * np.log10(1 + hz / 229)
-
-def erbs2hz(erbs):
-    return (10 ** (erbs / 21.4) - 1) * 229
-
-def sieve(n):
-    primes = list(range(2, n + 1))
-    num = 2
-
-    while num < math.sqrt(n):
-        i = num
-
-        while i <= n:
-            i += num
-
-            if i in primes: primes.remove(i)
-
-        for j in primes:
-            if j > num:
-                num = j
-                break
-
-    return primes
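swipe() above is a self-contained pitch estimator over NumPy/SciPy/matplotlib; it returns the F0 track and the corresponding frame times. A small sketch on a synthetic tone (the test signal is an assumption for illustration):

import numpy as np
from main.library.predictors.SWIPE import swipe

sr = 16000
t = np.arange(sr) / sr
tone = 0.5 * np.sin(2 * np.pi * 220.0 * t)   # 220 Hz test tone, one second
f0, times = swipe(tone, sr, f0_floor=50, f0_ceil=1100, frame_period=10)
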
main/library/predictors/WORLD_WRAPPER.py
DELETED
@@ -1,90 +0,0 @@
-import os
-import torch
-import ctypes
-import platform
-
-import numpy as np
-
-
-
-class DioOption(ctypes.Structure):
-    _fields_ = [("F0Floor", ctypes.c_double), ("F0Ceil", ctypes.c_double), ("ChannelsInOctave", ctypes.c_double), ("FramePeriod", ctypes.c_double), ("Speed", ctypes.c_int), ("AllowedRange", ctypes.c_double)]
-
-class HarvestOption(ctypes.Structure):
-    _fields_ = [("F0Floor", ctypes.c_double), ("F0Ceil", ctypes.c_double), ("FramePeriod", ctypes.c_double)]
-
-class PYWORLD:
-    def __init__(self):
-        self.world_path = os.path.join("assets", "models", "predictors", "world")
-        os.makedirs(self.world_path, exist_ok=True)
-
-        model_type, suffix = (("world_64" if platform.architecture()[0] == "64bit" else "world_86"), ".dll") if platform.system() == "Windows" else ("world_linux", ".so")
-        self.world_file_path = os.path.join(self.world_path, f"{model_type}{suffix}")
-
-        if not os.path.exists(self.world_file_path):
-            model = torch.load(os.path.join("assets", "models", "predictors", "world.pth"), map_location="cpu")
-
-            with open(self.world_file_path, "wb") as w:
-                w.write(model[model_type])
-
-        self.world_dll = ctypes.CDLL(self.world_file_path)
-
-    def harvest(self, x, fs, f0_floor=50, f0_ceil=1100, frame_period=10):
-        self.world_dll.Harvest.argtypes = [ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.c_int, ctypes.POINTER(HarvestOption), ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double)]
-        self.world_dll.Harvest.restype = None
-
-        self.world_dll.InitializeHarvestOption.argtypes = [ctypes.POINTER(HarvestOption)]
-        self.world_dll.InitializeHarvestOption.restype = None
-
-        self.world_dll.GetSamplesForHarvest.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_double]
-        self.world_dll.GetSamplesForHarvest.restype = ctypes.c_int
-
-        option = HarvestOption()
-        self.world_dll.InitializeHarvestOption(ctypes.byref(option))
-
-        option.F0Floor = f0_floor
-        option.F0Ceil = f0_ceil
-        option.FramePeriod = frame_period
-
-        f0_length = self.world_dll.GetSamplesForHarvest(fs, len(x), option.FramePeriod)
-        f0 = (ctypes.c_double * f0_length)()
-        tpos = (ctypes.c_double * f0_length)()
-
-        self.world_dll.Harvest((ctypes.c_double * len(x))(*x), len(x), fs, ctypes.byref(option), tpos, f0)
-        return np.array(f0, dtype=np.float32), np.array(tpos, dtype=np.float32)
-
-    def dio(self, x, fs, f0_floor=50, f0_ceil=1100, channels_in_octave=2, frame_period=10, speed=1, allowed_range=0.1):
-        self.world_dll.Dio.argtypes = [ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.c_int, ctypes.POINTER(DioOption), ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double)]
-        self.world_dll.Dio.restype = None
-
-        self.world_dll.InitializeDioOption.argtypes = [ctypes.POINTER(DioOption)]
-        self.world_dll.InitializeDioOption.restype = None
-
-        self.world_dll.GetSamplesForDIO.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_double]
-        self.world_dll.GetSamplesForDIO.restype = ctypes.c_int
-
-        option = DioOption()
-        self.world_dll.InitializeDioOption(ctypes.byref(option))
-
-        option.F0Floor = f0_floor
-        option.F0Ceil = f0_ceil
-        option.ChannelsInOctave = channels_in_octave
-        option.FramePeriod = frame_period
-        option.Speed = speed
-        option.AllowedRange = allowed_range
-
-        f0_length = self.world_dll.GetSamplesForDIO(fs, len(x), option.FramePeriod)
-        f0 = (ctypes.c_double * f0_length)()
-        tpos = (ctypes.c_double * f0_length)()
-
-        self.world_dll.Dio((ctypes.c_double * len(x))(*x), len(x), fs, ctypes.byref(option), tpos, f0)
-        return np.array(f0, dtype=np.float32), np.array(tpos, dtype=np.float32)
-
-    def stonemask(self, x, fs, tpos, f0):
-        self.world_dll.StoneMask.argtypes = [ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.POINTER(ctypes.c_double)]
-        self.world_dll.StoneMask.restype = None
-
-        out_f0 = (ctypes.c_double * len(f0))()
-        self.world_dll.StoneMask((ctypes.c_double * len(x))(*x), len(x), fs, (ctypes.c_double * len(tpos))(*tpos), (ctypes.c_double * len(f0))(*f0), len(f0), out_f0)
-
-        return np.array(out_f0, dtype=np.float32)
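The PYWORLD wrapper deleted above unpacks a bundled WORLD shared library and exposes Harvest, Dio and StoneMask over ctypes. A minimal sketch, assuming the bundled assets/models/predictors/world.pth payload is still present (the dummy signal is illustrative):

import numpy as np
from main.library.predictors.WORLD_WRAPPER import PYWORLD

world = PYWORLD()                         # extracts and loads the WORLD DLL/.so on first use
x = np.zeros(16000, dtype=np.float64)     # one second of 16 kHz audio
f0, tpos = world.harvest(x, 16000, f0_floor=50, f0_ceil=1100, frame_period=10)
f0 = world.stonemask(x, 16000, tpos, f0)  # refine the coarse Harvest estimate
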
main/library/speaker_diarization/ECAPA_TDNN.py
DELETED
@@ -1,280 +0,0 @@
-import math
-import torch
-
-import torch.nn as nn
-import torch.nn.functional as F
-
-def length_to_mask(length, max_len=None, dtype=None, device=None):
-    assert len(length.shape) == 1
-
-    if max_len is None: max_len = length.max().long().item()
-
-    mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand(len(length), max_len) < length.unsqueeze(1)
-
-    if dtype is None: dtype = length.dtype
-    if device is None: device = length.device
-
-    return torch.as_tensor(mask, dtype=dtype, device=device)
-
-def get_padding_elem(L_in, stride, kernel_size, dilation):
-    if stride > 1: padding = [math.floor(kernel_size / 2), math.floor(kernel_size / 2)]
-    else:
-        L_out = (math.floor((L_in - dilation * (kernel_size - 1) - 1) / stride) + 1)
-        padding = [math.floor((L_in - L_out) / 2), math.floor((L_in - L_out) / 2)]
-
-    return padding
-
-class _BatchNorm1d(nn.Module):
-    def __init__(self, input_shape=None, input_size=None, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, combine_batch_time=False, skip_transpose=False):
-        super().__init__()
-        self.combine_batch_time = combine_batch_time
-        self.skip_transpose = skip_transpose
-
-        if input_size is None and skip_transpose: input_size = input_shape[1]
-        elif input_size is None: input_size = input_shape[-1]
-
-        self.norm = nn.BatchNorm1d(input_size, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
-
-    def forward(self, x):
-        shape_or = x.shape
-
-        if self.combine_batch_time:x = x.reshape(shape_or[0] * shape_or[1], shape_or[2]) if x.ndim == 3 else x.reshape(shape_or[0] * shape_or[1], shape_or[3], shape_or[2])
-        elif not self.skip_transpose: x = x.transpose(-1, 1)
-
-        x_n = self.norm(x)
-
-        if self.combine_batch_time: x_n = x_n.reshape(shape_or)
-        elif not self.skip_transpose: x_n = x_n.transpose(1, -1)
-
-        return x_n
-
-class _Conv1d(nn.Module):
-    def __init__(self, out_channels, kernel_size, input_shape=None, in_channels=None, stride=1, dilation=1, padding="same", groups=1, bias=True, padding_mode="reflect", skip_transpose=False, weight_norm=False, conv_init=None, default_padding=0):
-        super().__init__()
-        self.kernel_size = kernel_size
-        self.stride = stride
-        self.dilation = dilation
-        self.padding = padding
-        self.padding_mode = padding_mode
-        self.unsqueeze = False
-        self.skip_transpose = skip_transpose
-
-        if input_shape is None and in_channels is None: raise ValueError
-        if in_channels is None: in_channels = self._check_input_shape(input_shape)
-
-        self.in_channels = in_channels
-        self.conv = nn.Conv1d(in_channels, out_channels, self.kernel_size, stride=self.stride, dilation=self.dilation, padding=default_padding, groups=groups, bias=bias)
-
-        if conv_init == "kaiming": nn.init.kaiming_normal_(self.conv.weight)
-        elif conv_init == "zero": nn.init.zeros_(self.conv.weight)
-        elif conv_init == "normal": nn.init.normal_(self.conv.weight, std=1e-6)
-
-        if weight_norm: self.conv = nn.utils.weight_norm(self.conv)
-
-    def forward(self, x):
-        if not self.skip_transpose: x = x.transpose(1, -1)
-        if self.unsqueeze: x = x.unsqueeze(1)
-
-        if self.padding == "same": x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
-        elif self.padding == "causal": x = F.pad(x, ((self.kernel_size - 1) * self.dilation, 0))
-        elif self.padding == "valid": pass
-        else: raise ValueError
-
-        wx = self.conv(x)
-
-        if self.unsqueeze: wx = wx.squeeze(1)
-        if not self.skip_transpose: wx = wx.transpose(1, -1)
-
-        return wx
-
-    def _manage_padding(self, x, kernel_size, dilation, stride):
-        return F.pad(x, get_padding_elem(self.in_channels, stride, kernel_size, dilation), mode=self.padding_mode)
-
-    def _check_input_shape(self, shape):
-        if len(shape) == 2:
-            self.unsqueeze = True
-            in_channels = 1
-        elif self.skip_transpose: in_channels = shape[1]
-        elif len(shape) == 3: in_channels = shape[2]
-        else: raise ValueError
-
-        if not self.padding == "valid" and self.kernel_size % 2 == 0: raise ValueError
-        return in_channels
-
-    def remove_weight_norm(self):
-        self.conv = nn.utils.remove_weight_norm(self.conv)
-
-class Linear(torch.nn.Module):
-    def __init__(self, n_neurons, input_shape=None, input_size=None, bias=True, max_norm=None, combine_dims=False):
-        super().__init__()
-        self.max_norm = max_norm
-        self.combine_dims = combine_dims
-
-        if input_shape is None and input_size is None: raise ValueError
-        if input_size is None:
-            input_size = input_shape[-1]
-            if len(input_shape) == 4 and self.combine_dims: input_size = input_shape[2] * input_shape[3]
-
-        self.w = nn.Linear(input_size, n_neurons, bias=bias)
-
-    def forward(self, x):
-        if x.ndim == 4 and self.combine_dims: x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
-        if self.max_norm is not None: self.w.weight.data = torch.renorm(self.w.weight.data, p=2, dim=0, maxnorm=self.max_norm)
-
-        return self.w(x)
-
-class Conv1d(_Conv1d):
-    def __init__(self, *args, **kwargs):
-        super().__init__(skip_transpose=True, *args, **kwargs)
-
-class BatchNorm1d(_BatchNorm1d):
-    def __init__(self, *args, **kwargs):
-        super().__init__(skip_transpose=True, *args, **kwargs)
-
-class TDNNBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, dilation, activation=nn.ReLU, groups=1, dropout=0.0):
-        super().__init__()
-        self.conv = Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, dilation=dilation, groups=groups)
-        self.activation = activation()
-        self.norm = BatchNorm1d(input_size=out_channels)
-        self.dropout = nn.Dropout1d(p=dropout)
-
-    def forward(self, x):
-        return self.dropout(self.norm(self.activation(self.conv(x))))
-
-class Res2NetBlock(torch.nn.Module):
-    def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1, dropout=0.0):
-        super().__init__()
-        assert in_channels % scale == 0
-        assert out_channels % scale == 0
-        in_channel = in_channels // scale
-        hidden_channel = out_channels // scale
-        self.blocks = nn.ModuleList([TDNNBlock(in_channel, hidden_channel, kernel_size=kernel_size, dilation=dilation, dropout=dropout) for _ in range(scale - 1)])
-        self.scale = scale
-
-    def forward(self, x):
-        y = []
-
-        for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
-            if i == 0: y_i = x_i
-            elif i == 1: y_i = self.blocks[i - 1](x_i)
-            else: y_i = self.blocks[i - 1](x_i + y_i)
-
-            y.append(y_i)
-
-        return torch.cat(y, dim=1)
-
-class SEBlock(nn.Module):
-    def __init__(self, in_channels, se_channels, out_channels):
-        super().__init__()
-
-        self.conv1 = Conv1d(in_channels=in_channels, out_channels=se_channels, kernel_size=1)
-        self.relu = torch.nn.ReLU(inplace=True)
-        self.conv2 = Conv1d(in_channels=se_channels, out_channels=out_channels, kernel_size=1)
-        self.sigmoid = torch.nn.Sigmoid()
-
-    def forward(self, x, lengths=None):
-        L = x.shape[-1]
-
-        if lengths is not None:
-            mask = length_to_mask(lengths * L, max_len=L, device=x.device).unsqueeze(1)
-            s = (x * mask).sum(dim=2, keepdim=True) / mask.sum(dim=2, keepdim=True)
-        else: s = x.mean(dim=2, keepdim=True)
-
-        return self.sigmoid(self.conv2(self.relu(self.conv1(s)))) * x
-
-class AttentiveStatisticsPooling(nn.Module):
-    def __init__(self, channels, attention_channels=128, global_context=True):
-        super().__init__()
-        self.eps = 1e-12
-        self.global_context = global_context
-        self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1) if global_context else TDNNBlock(channels, attention_channels, 1, 1)
-        self.tanh = nn.Tanh()
-        self.conv = Conv1d(in_channels=attention_channels, out_channels=channels, kernel_size=1)
-
-    def forward(self, x, lengths=None):
-        L = x.shape[-1]
-
-        def _compute_statistics(x, m, dim=2, eps=self.eps):
-            mean = (m * x).sum(dim)
-            return mean, torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
-
-        if lengths is None: lengths = torch.ones(x.shape[0], device=x.device)
-        mask = length_to_mask(lengths * L, max_len=L, device=x.device).unsqueeze(1)
-
-        if self.global_context:
-            mean, std = _compute_statistics(x, mask / mask.sum(dim=2, keepdim=True).float())
-            attn = torch.cat([x, mean.unsqueeze(2).repeat(1, 1, L), std.unsqueeze(2).repeat(1, 1, L)], dim=1)
-        else: attn = x
-
-        mean, std = _compute_statistics(x, F.softmax(self.conv(self.tanh(self.tdnn(attn))).masked_fill(mask == 0, float("-inf")), dim=2))
-        return torch.cat((mean, std), dim=1).unsqueeze(2)
-
-class SERes2NetBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, res2net_scale=8, se_channels=128, kernel_size=1, dilation=1, activation=torch.nn.ReLU, groups=1, dropout=0.0):
-        super().__init__()
-        self.out_channels = out_channels
-        self.tdnn1 = TDNNBlock(in_channels, out_channels, kernel_size=1, dilation=1, activation=activation, groups=groups, dropout=dropout)
-        self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation)
-        self.tdnn2 = TDNNBlock(out_channels, out_channels, kernel_size=1, dilation=1, activation=activation, groups=groups, dropout=dropout)
-        self.se_block = SEBlock(out_channels, se_channels, out_channels)
-
-        self.shortcut = None
-        if in_channels != out_channels: self.shortcut = Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
-
-    def forward(self, x, lengths=None):
-        residual = x
-        if self.shortcut: residual = self.shortcut(x)
-
-        return self.se_block(self.tdnn2(self.res2net_block(self.tdnn1(x))), lengths) + residual
-
-class ECAPA_TDNN(torch.nn.Module):
-    def __init__(self, input_size, device="cpu", lin_neurons=192, activation=torch.nn.ReLU, channels=[512, 512, 512, 512, 1536], kernel_sizes=[5, 3, 3, 3, 1], dilations=[1, 2, 3, 4, 1], attention_channels=128, res2net_scale=8, se_channels=128, global_context=True, groups=[1, 1, 1, 1, 1], dropout=0.0):
-        super().__init__()
-        assert len(channels) == len(kernel_sizes)
-        assert len(channels) == len(dilations)
-
-        self.channels = channels
-        self.blocks = nn.ModuleList()
-
-        self.blocks.append(TDNNBlock(input_size, channels[0], kernel_sizes[0], dilations[0], activation, groups[0], dropout))
-
-        for i in range(1, len(channels) - 1):
-            self.blocks.append(SERes2NetBlock(channels[i - 1], channels[i], res2net_scale=res2net_scale, se_channels=se_channels, kernel_size=kernel_sizes[i], dilation=dilations[i], activation=activation, groups=groups[i], dropout=dropout))
-
-        self.mfa = TDNNBlock(channels[-2] * (len(channels) - 2), channels[-1], kernel_sizes[-1], dilations[-1], activation, groups=groups[-1], dropout=dropout)
-        self.asp = AttentiveStatisticsPooling(channels[-1], attention_channels=attention_channels, global_context=global_context)
-        self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)
-        self.fc = Conv1d(in_channels=channels[-1] * 2, out_channels=lin_neurons, kernel_size=1)
-
-    def forward(self, x, lengths=None):
-        x = x.transpose(1, 2)
-
-        xl = []
-        for layer in self.blocks:
-            try:
-                x = layer(x, lengths=lengths)
-            except TypeError:
-                x = layer(x)
-
-            xl.append(x)
-
-        return self.fc(self.asp_bn(self.asp(self.mfa(torch.cat(xl[1:], dim=1)), lengths=lengths))).transpose(1, 2)
-
-class Classifier(torch.nn.Module):
-    def __init__(self, input_size, device="cpu", lin_blocks=0, lin_neurons=192, out_neurons=1211):
-        super().__init__()
-        self.blocks = nn.ModuleList()
-
-        for _ in range(lin_blocks):
-            self.blocks.extend([_BatchNorm1d(input_size=input_size), Linear(input_size=input_size, n_neurons=lin_neurons)])
-            input_size = lin_neurons
-
-        self.weight = nn.Parameter(torch.FloatTensor(out_neurons, input_size, device=device))
-        nn.init.xavier_uniform_(self.weight)
-
-    def forward(self, x):
-        for layer in self.blocks:
-            x = layer(x)
-
-        return F.linear(F.normalize(x.squeeze(1)), F.normalize(self.weight)).unsqueeze(1)
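ECAPA_TDNN above consumes frame-level features shaped (batch, time, features) and returns one embedding per utterance. A minimal sketch with random features (the shapes are assumptions for illustration):

import torch
from main.library.speaker_diarization.ECAPA_TDNN import ECAPA_TDNN

model = ECAPA_TDNN(input_size=80, lin_neurons=192)   # 80-dim filterbank frames in, 192-dim embedding out
feats = torch.randn(4, 200, 80)                      # (batch, frames, features)
emb = model(feats)                                   # -> (4, 1, 192) speaker embeddings
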
main/library/speaker_diarization/audio.py
DELETED
@@ -1,170 +0,0 @@
-import os
-import math
-import random
-import torchaudio
-
-from io import IOBase
-from torch.nn.functional import pad
-
-def get_torchaudio_info(file, backend = None):
-    if not backend:
-        backends = (torchaudio.list_audio_backends())
-        backend = "soundfile" if "soundfile" in backends else backends[0]
-
-    info = torchaudio.info(file["audio"], backend=backend)
-    if isinstance(file["audio"], IOBase): file["audio"].seek(0)
-
-    return info
-
-class Audio:
-    @staticmethod
-    def power_normalize(waveform):
-        return waveform / (waveform.square().mean(dim=-1, keepdim=True).sqrt() + 1e-8)
-
-    @staticmethod
-    def validate_file(file):
-        if isinstance(file, (str, os.PathLike)): file = {"audio": str(file), "uri": os.path.splitext(os.path.basename(file))[0]}
-        elif isinstance(file, IOBase): return {"audio": file, "uri": "stream"}
-        else: raise ValueError
-
-        if "waveform" in file:
-            waveform = file["waveform"]
-            if len(waveform.shape) != 2 or waveform.shape[0] > waveform.shape[1]: raise ValueError
-
-            sample_rate: int = file.get("sample_rate", None)
-            if sample_rate is None: raise ValueError
-
-            file.setdefault("uri", "waveform")
-
-        elif "audio" in file:
-            if isinstance(file["audio"], IOBase): return file
-
-            path = os.path.abspath(file["audio"])
-            file.setdefault("uri", os.path.splitext(os.path.basename(path))[0])
-
-        else: raise ValueError
-
-        return file
-
-    def __init__(self, sample_rate: int = None, mono=None, backend: str = None):
-        super().__init__()
-        self.sample_rate = sample_rate
-        self.mono = mono
-
-        if not backend:
-            backends = (torchaudio.list_audio_backends())
-            backend = "soundfile" if "soundfile" in backends else backends[0]
-
-        self.backend = backend
-
-    def downmix_and_resample(self, waveform, sample_rate):
-        num_channels = waveform.shape[0]
-
-        if num_channels > 1:
-            if self.mono == "random":
-                channel = random.randint(0, num_channels - 1)
-                waveform = waveform[channel : channel + 1]
-            elif self.mono == "downmix": waveform = waveform.mean(dim=0, keepdim=True)
-
-        if (self.sample_rate is not None) and (self.sample_rate != sample_rate):
-            waveform = torchaudio.functional.resample(waveform, sample_rate, self.sample_rate)
-            sample_rate = self.sample_rate
-
-        return waveform, sample_rate
-
-    def get_duration(self, file):
-        file = self.validate_file(file)
-
-        if "waveform" in file:
-            frames = len(file["waveform"].T)
-            sample_rate = file["sample_rate"]
-        else:
-            info = file["torchaudio.info"] if "torchaudio.info" in file else get_torchaudio_info(file, backend=self.backend)
-            frames = info.num_frames
-            sample_rate = info.sample_rate
-
-        return frames / sample_rate
-
-    def get_num_samples(self, duration, sample_rate = None):
-        sample_rate = sample_rate or self.sample_rate
-        if sample_rate is None: raise ValueError
-
-        return math.floor(duration * sample_rate)
-
-    def __call__(self, file):
-        file = self.validate_file(file)
-
-        if "waveform" in file:
-            waveform = file["waveform"]
-            sample_rate = file["sample_rate"]
-        elif "audio" in file:
-            waveform, sample_rate = torchaudio.load(file["audio"], backend=self.backend)
-            if isinstance(file["audio"], IOBase): file["audio"].seek(0)
-
-        channel = file.get("channel", None)
-        if channel is not None: waveform = waveform[channel : channel + 1]
-
-        return self.downmix_and_resample(waveform, sample_rate)
-
-    def crop(self, file, segment, duration = None, mode="raise"):
-        file = self.validate_file(file)
-
-        if "waveform" in file:
-            waveform = file["waveform"]
-            frames = waveform.shape[1]
-            sample_rate = file["sample_rate"]
-        elif "torchaudio.info" in file:
-            info = file["torchaudio.info"]
-            frames = info.num_frames
-            sample_rate = info.sample_rate
-        else:
-            info = get_torchaudio_info(file, backend=self.backend)
-            frames = info.num_frames
-            sample_rate = info.sample_rate
-
-        channel = file.get("channel", None)
-        start_frame = math.floor(segment.start * sample_rate)
-
-        if duration:
-            num_frames = math.floor(duration * sample_rate)
-            end_frame = start_frame + num_frames
-        else:
-            end_frame = math.floor(segment.end * sample_rate)
-            num_frames = end_frame - start_frame
-
-        if mode == "raise":
-            if num_frames > frames: raise ValueError
-
-            if end_frame > frames + math.ceil(0.001 * sample_rate): raise ValueError
-            else:
-                end_frame = min(end_frame, frames)
-                start_frame = end_frame - num_frames
-
-            if start_frame < 0: raise ValueError
-        elif mode == "pad":
-            pad_start = -min(0, start_frame)
-            pad_end = max(end_frame, frames) - frames
-
-            start_frame = max(0, start_frame)
-            end_frame = min(end_frame, frames)
-
-            num_frames = end_frame - start_frame
-
-        if "waveform" in file: data = file["waveform"][:, start_frame:end_frame]
-        else:
-            try:
-                data, _ = torchaudio.load(file["audio"], frame_offset=start_frame, num_frames=num_frames, backend=self.backend)
-                if isinstance(file["audio"], IOBase): file["audio"].seek(0)
-            except RuntimeError:
-                if isinstance(file["audio"], IOBase): raise RuntimeError
-
-                waveform, sample_rate = self.__call__(file)
-                data = waveform[:, start_frame:end_frame]
-
-                file["waveform"] = waveform
-                file["sample_rate"] = sample_rate
-
-        if channel is not None: data = data[channel : channel + 1, :]
-        if mode == "pad": data = pad(data, (pad_start, pad_end))
-
-        return self.downmix_and_resample(data, sample_rate)
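The Audio helper deleted above normalizes arbitrary inputs (paths, file objects, preloaded waveforms) to mono tensors at a target rate. A small sketch; the file name is a placeholder, not something referenced by this diff:

from main.library.speaker_diarization.audio import Audio

loader = Audio(sample_rate=16000, mono="downmix")   # resample to 16 kHz and downmix to mono
waveform, sr = loader("speech.wav")                 # (channels, samples) tensor plus its sample rate
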
main/library/speaker_diarization/embedding.py
DELETED
@@ -1,90 +0,0 @@
-import os
-import sys
-import torch
-
-import numpy as np
-import torch.nn.functional as F
-
-from functools import cached_property
-from torch.nn.utils.rnn import pad_sequence
-
-sys.path.append(os.getcwd())
-
-from main.library.speaker_diarization.speechbrain import EncoderClassifier
-
-class BaseInference:
-    pass
-
-class SpeechBrainPretrainedSpeakerEmbedding(BaseInference):
-    def __init__(self, embedding = "assets/models/speaker_diarization/models/speechbrain", device = None):
-        super().__init__()
-
-        self.embedding = embedding
-        self.device = device or torch.device("cpu")
-        self.classifier_ = EncoderClassifier.from_hparams(source=self.embedding, run_opts={"device": self.device})
-
-    def to(self, device):
-        if not isinstance(device, torch.device): raise TypeError
-
-        self.classifier_ = EncoderClassifier.from_hparams(source=self.embedding, run_opts={"device": device})
-        self.device = device
-        return self
-
-    @cached_property
-    def sample_rate(self):
-        return self.classifier_.audio_normalizer.sample_rate
-
-    @cached_property
-    def dimension(self):
-        *_, dimension = self.classifier_.encode_batch(torch.rand(1, 16000).to(self.device)).shape
-        return dimension
-
-    @cached_property
-    def metric(self):
-        return "cosine"
-
-    @cached_property
-    def min_num_samples(self):
-        with torch.inference_mode():
-            lower, upper = 2, round(0.5 * self.sample_rate)
-            middle = (lower + upper) // 2
-
-            while lower + 1 < upper:
-                try:
-                    _ = self.classifier_.encode_batch(torch.randn(1, middle).to(self.device))
-                    upper = middle
-                except RuntimeError:
-                    lower = middle
-
-                middle = (lower + upper) // 2
-
-        return upper
-
-    def __call__(self, waveforms, masks = None):
-        batch_size, num_channels, num_samples = waveforms.shape
-        assert num_channels == 1
-
-        waveforms = waveforms.squeeze(dim=1)
-
-        if masks is None:
-            signals = waveforms.squeeze(dim=1)
-            wav_lens = signals.shape[1] * torch.ones(batch_size)
-        else:
-            batch_size_masks, _ = masks.shape
-            assert batch_size == batch_size_masks
-
-            imasks = F.interpolate(masks.unsqueeze(dim=1), size=num_samples, mode="nearest").squeeze(dim=1) > 0.5
-            signals = pad_sequence([waveform[imask].contiguous() for waveform, imask in zip(waveforms, imasks)], batch_first=True)
-            wav_lens = imasks.sum(dim=1)
-
-        max_len = wav_lens.max()
-        if max_len < self.min_num_samples: return np.nan * np.zeros((batch_size, self.dimension))
-
-        too_short = wav_lens < self.min_num_samples
-        wav_lens = wav_lens / max_len
-        wav_lens[too_short] = 1.0
-
-        embeddings = (self.classifier_.encode_batch(signals, wav_lens=wav_lens).squeeze(dim=1).cpu().numpy())
-        embeddings[too_short.cpu().numpy()] = np.nan
-
-        return embeddings
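SpeechBrainPretrainedSpeakerEmbedding above wraps a SpeechBrain EncoderClassifier and turns batches of mono chunks into speaker vectors. A sketch with random chunks; the batch shape is an assumption, and it relies on the default local model directory shown in the class signature:

import torch
from main.library.speaker_diarization.embedding import SpeechBrainPretrainedSpeakerEmbedding

emb = SpeechBrainPretrainedSpeakerEmbedding(device=torch.device("cpu"))
chunks = torch.randn(8, 1, 32000)   # 8 mono chunks of 2 s at 16 kHz
vectors = emb(chunks)               # numpy array (8, D); rows are NaN for chunks that are too short
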
main/library/speaker_diarization/encoder.py
DELETED
@@ -1,250 +0,0 @@
-import os
-import sys
-import ast
-import torch
-import itertools
-import collections
-
-sys.path.append(os.getcwd())
-
-from main.library.speaker_diarization.speechbrain import if_main_process, ddp_barrier
-from main.library.speaker_diarization.features import register_checkpoint_hooks, mark_as_saver, mark_as_loader
-
-DEFAULT_UNK = "<unk>"
-DEFAULT_BOS = "<bos>"
-DEFAULT_EOS = "<eos>"
-DEFAULT_BLANK = "<blank>"
-
-@register_checkpoint_hooks
-class CategoricalEncoder:
-    VALUE_SEPARATOR = " => "
-    EXTRAS_SEPARATOR = "================\n"
-
-    def __init__(self, starting_index=0, **special_labels):
-        self.lab2ind = {}
-        self.ind2lab = {}
-        self.starting_index = starting_index
-        self.handle_special_labels(special_labels)
-
-    def handle_special_labels(self, special_labels):
-        if "unk_label" in special_labels: self.add_unk(special_labels["unk_label"])
-
-    def __len__(self):
-        return len(self.lab2ind)
-
-    @classmethod
-    def from_saved(cls, path):
-        obj = cls()
-        obj.load(path)
-        return obj
-
-    def update_from_iterable(self, iterable, sequence_input=False):
-        label_iterator = itertools.chain.from_iterable(iterable) if sequence_input else iter(iterable)
-        for label in label_iterator:
-            self.ensure_label(label)
-
-    def update_from_didataset(self, didataset, output_key, sequence_input=False):
-        with didataset.output_keys_as([output_key]):
-            self.update_from_iterable((data_point[output_key] for data_point in didataset), sequence_input=sequence_input)
-
-    def limited_labelset_from_iterable(self, iterable, sequence_input=False, n_most_common=None, min_count=1):
-        label_iterator = itertools.chain.from_iterable(iterable) if sequence_input else iter(iterable)
-        counts = collections.Counter(label_iterator)
-
-        for label, count in counts.most_common(n_most_common):
-            if count < min_count: break
-            self.add_label(label)
-
-        return counts
-
-    def load_or_create(self, path, from_iterables=[], from_didatasets=[], sequence_input=False, output_key=None, special_labels={}):
-        try:
-            if if_main_process():
-                if not self.load_if_possible(path):
-                    for iterable in from_iterables:
-                        self.update_from_iterable(iterable, sequence_input)
-
-                    for didataset in from_didatasets:
-                        if output_key is None: raise ValueError
-                        self.update_from_didataset(didataset, output_key, sequence_input)
-
-                    self.handle_special_labels(special_labels)
-                    self.save(path)
-        finally:
-            ddp_barrier()
-            self.load(path)
-
-    def add_label(self, label):
-        if label in self.lab2ind: raise KeyError
-        index = self._next_index()
-
-        self.lab2ind[label] = index
-        self.ind2lab[index] = label
-
-        return index
-
-    def ensure_label(self, label):
-        if label in self.lab2ind: return self.lab2ind[label]
-        else: return self.add_label(label)
-
-    def insert_label(self, label, index):
-        if label in self.lab2ind: raise KeyError
-        else: self.enforce_label(label, index)
-
-    def enforce_label(self, label, index):
-        index = int(index)
-
-        if label in self.lab2ind:
-            if index == self.lab2ind[label]: return
-            else: del self.ind2lab[self.lab2ind[label]]
-
-        if index in self.ind2lab:
-            saved_label = self.ind2lab[index]
-            moving_other = True
-        else: moving_other = False
-
-        self.lab2ind[label] = index
-        self.ind2lab[index] = label
-
-        if moving_other:
-            new_index = self._next_index()
-            self.lab2ind[saved_label] = new_index
-            self.ind2lab[new_index] = saved_label
-
-    def add_unk(self, unk_label=DEFAULT_UNK):
-        self.unk_label = unk_label
-        return self.add_label(unk_label)
-
-    def _next_index(self):
-        index = self.starting_index
-        while index in self.ind2lab:
-            index += 1
-
-        return index
-
-    def is_continuous(self):
-        indices = sorted(self.ind2lab.keys())
-        return self.starting_index in indices and all(j - i == 1 for i, j in zip(indices[:-1], indices[1:]))
128 |
-
|
129 |
-
def encode_label(self, label, allow_unk=True):
|
130 |
-
self._assert_len()
|
131 |
-
|
132 |
-
try:
|
133 |
-
return self.lab2ind[label]
|
134 |
-
except KeyError:
|
135 |
-
if hasattr(self, "unk_label") and allow_unk: return self.lab2ind[self.unk_label]
|
136 |
-
elif hasattr(self, "unk_label") and not allow_unk: raise KeyError
|
137 |
-
elif not hasattr(self, "unk_label") and allow_unk: raise KeyError
|
138 |
-
else: raise KeyError
|
139 |
-
|
140 |
-
def encode_label_torch(self, label, allow_unk=True):
|
141 |
-
return torch.LongTensor([self.encode_label(label, allow_unk)])
|
142 |
-
|
143 |
-
def encode_sequence(self, sequence, allow_unk=True):
|
144 |
-
self._assert_len()
|
145 |
-
return [self.encode_label(label, allow_unk) for label in sequence]
|
146 |
-
|
147 |
-
def encode_sequence_torch(self, sequence, allow_unk=True):
|
148 |
-
return torch.LongTensor([self.encode_label(label, allow_unk) for label in sequence])
|
149 |
-
|
150 |
-
def decode_torch(self, x):
|
151 |
-
self._assert_len()
|
152 |
-
decoded = []
|
153 |
-
|
154 |
-
if x.ndim == 1:
|
155 |
-
for element in x:
|
156 |
-
decoded.append(self.ind2lab[int(element)])
|
157 |
-
else:
|
158 |
-
for subtensor in x:
|
159 |
-
decoded.append(self.decode_torch(subtensor))
|
160 |
-
|
161 |
-
return decoded
|
162 |
-
|
163 |
-
def decode_ndim(self, x):
|
164 |
-
self._assert_len()
|
165 |
-
try:
|
166 |
-
decoded = []
|
167 |
-
for subtensor in x:
|
168 |
-
decoded.append(self.decode_ndim(subtensor))
|
169 |
-
|
170 |
-
return decoded
|
171 |
-
except TypeError:
|
172 |
-
return self.ind2lab[int(x)]
|
173 |
-
|
174 |
-
@mark_as_saver
|
175 |
-
def save(self, path):
|
176 |
-
self._save_literal(path, self.lab2ind, self._get_extras())
|
177 |
-
|
178 |
-
def load(self, path):
|
179 |
-
lab2ind, ind2lab, extras = self._load_literal(path)
|
180 |
-
self.lab2ind = lab2ind
|
181 |
-
self.ind2lab = ind2lab
|
182 |
-
self._set_extras(extras)
|
183 |
-
|
184 |
-
@mark_as_loader
|
185 |
-
def load_if_possible(self, path, end_of_epoch=False):
|
186 |
-
del end_of_epoch
|
187 |
-
|
188 |
-
try:
|
189 |
-
self.load(path)
|
190 |
-
except FileNotFoundError:
|
191 |
-
return False
|
192 |
-
except (ValueError, SyntaxError):
|
193 |
-
return False
|
194 |
-
|
195 |
-
return True
|
196 |
-
|
197 |
-
def expect_len(self, expected_len):
|
198 |
-
self.expected_len = expected_len
|
199 |
-
|
200 |
-
def ignore_len(self):
|
201 |
-
self.expected_len = None
|
202 |
-
|
203 |
-
def _assert_len(self):
|
204 |
-
if hasattr(self, "expected_len"):
|
205 |
-
if self.expected_len is None: return
|
206 |
-
if len(self) != self.expected_len: raise RuntimeError
|
207 |
-
else:
|
208 |
-
self.ignore_len()
|
209 |
-
return
|
210 |
-
|
211 |
-
def _get_extras(self):
|
212 |
-
extras = {"starting_index": self.starting_index}
|
213 |
-
if hasattr(self, "unk_label"): extras["unk_label"] = self.unk_label
|
214 |
-
|
215 |
-
return extras
|
216 |
-
|
217 |
-
def _set_extras(self, extras):
|
218 |
-
if "unk_label" in extras: self.unk_label = extras["unk_label"]
|
219 |
-
self.starting_index = extras["starting_index"]
|
220 |
-
|
221 |
-
@staticmethod
|
222 |
-
def _save_literal(path, lab2ind, extras):
|
223 |
-
with open(path, "w", encoding="utf-8") as f:
|
224 |
-
for label, ind in lab2ind.items():
|
225 |
-
f.write(repr(label) + CategoricalEncoder.VALUE_SEPARATOR + str(ind) + "\n")
|
226 |
-
|
227 |
-
f.write(CategoricalEncoder.EXTRAS_SEPARATOR)
|
228 |
-
|
229 |
-
for key, value in extras.items():
|
230 |
-
f.write(repr(key) + CategoricalEncoder.VALUE_SEPARATOR + repr(value) + "\n")
|
231 |
-
|
232 |
-
f.flush()
|
233 |
-
|
234 |
-
@staticmethod
|
235 |
-
def _load_literal(path):
|
236 |
-
lab2ind, ind2lab, extras = {}, {}, {}
|
237 |
-
|
238 |
-
with open(path, encoding="utf-8") as f:
|
239 |
-
for line in f:
|
240 |
-
if line == CategoricalEncoder.EXTRAS_SEPARATOR: break
|
241 |
-
literal, ind = line.strip().split(CategoricalEncoder.VALUE_SEPARATOR, maxsplit=1)
|
242 |
-
label = ast.literal_eval(literal)
|
243 |
-
lab2ind[label] = int(ind)
|
244 |
-
ind2lab[ind] = label
|
245 |
-
|
246 |
-
for line in f:
|
247 |
-
literal_key, literal_value = line.strip().split(CategoricalEncoder.VALUE_SEPARATOR, maxsplit=1)
|
248 |
-
extras[ast.literal_eval(literal_key)] = ast.literal_eval(literal_value)
|
249 |
-
|
250 |
-
return lab2ind, ind2lab, extras
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
main/library/speaker_diarization/features.py
DELETED
@@ -1,520 +0,0 @@
-import os
-import sys
-import math
-import torch
-import inspect
-import functools
-
-sys.path.append(os.getcwd())
-
-from main.library.speaker_diarization.speechbrain import MAIN_PROC_ONLY, is_distributed_initialized, main_process_only
-
-KEYS_MAPPING = {".mutihead_attn": ".multihead_attn", ".convs_intermedite": ".convs_intermediate"}
-
-def map_old_state_dict_weights(state_dict, mapping):
-    for replacement_old, replacement_new in mapping.items():
-        for old_key in list(state_dict.keys()):
-            if replacement_old in old_key: state_dict[old_key.replace(replacement_old, replacement_new)] = state_dict.pop(old_key)
-
-    return state_dict
-
-def hook_on_loading_state_dict_checkpoint(state_dict):
-    return map_old_state_dict_weights(state_dict, KEYS_MAPPING)
-
-def torch_patched_state_dict_load(path, device="cpu"):
-    return hook_on_loading_state_dict_checkpoint(torch.load(path, map_location=device))
-
-@main_process_only
-def torch_save(obj, path):
-    state_dict = obj.state_dict()
-    torch.save(state_dict, path)
-
-def torch_recovery(obj, path, end_of_epoch):
-    del end_of_epoch
-
-    state_dict = torch_patched_state_dict_load(path, "cpu")
-    try:
-        obj.load_state_dict(state_dict, strict=True)
-    except TypeError:
-        obj.load_state_dict(state_dict)
-
-def torch_parameter_transfer(obj, path):
-    incompatible_keys = obj.load_state_dict(torch_patched_state_dict_load(path, "cpu"), strict=False)
-
-    for missing_key in incompatible_keys.missing_keys:
-        pass
-    for unexpected_key in incompatible_keys.unexpected_keys:
-        pass
-
-WEAKREF_MARKER = "WEAKREF"
-
-def _cycliclrsaver(obj, path):
-    state_dict = obj.state_dict()
-    if state_dict.get("_scale_fn_ref") is not None: state_dict["_scale_fn_ref"] = WEAKREF_MARKER
-
-    torch.save(state_dict, path)
-
-def _cycliclrloader(obj, path, end_of_epoch):
-    del end_of_epoch
-
-    try:
-        obj.load_state_dict(torch.load(path, map_location="cpu"), strict=True)
-    except TypeError:
-        obj.load_state_dict(torch.load(path, map_location="cpu"))
-
-DEFAULT_LOAD_HOOKS = {torch.nn.Module: torch_recovery, torch.optim.Optimizer: torch_recovery, torch.optim.lr_scheduler.ReduceLROnPlateau: torch_recovery, torch.cuda.amp.grad_scaler.GradScaler: torch_recovery}
-DEFAULT_SAVE_HOOKS = { torch.nn.Module: torch_save, torch.optim.Optimizer: torch_save, torch.optim.lr_scheduler.ReduceLROnPlateau: torch_save, torch.cuda.amp.grad_scaler.GradScaler: torch_save}
-DEFAULT_LOAD_HOOKS[torch.optim.lr_scheduler.LRScheduler] = torch_recovery
-DEFAULT_SAVE_HOOKS[torch.optim.lr_scheduler.LRScheduler] = torch_save
-DEFAULT_TRANSFER_HOOKS = {torch.nn.Module: torch_parameter_transfer}
-DEFAULT_SAVE_HOOKS[torch.optim.lr_scheduler.CyclicLR] = _cycliclrsaver
-DEFAULT_LOAD_HOOKS[torch.optim.lr_scheduler.CyclicLR] = _cycliclrloader
-
-def register_checkpoint_hooks(cls, save_on_main_only=True):
-    global DEFAULT_LOAD_HOOKS, DEFAULT_SAVE_HOOKS, DEFAULT_TRANSFER_HOOKS
-
-    for name, method in cls.__dict__.items():
-        if hasattr(method, "_speechbrain_saver"): DEFAULT_SAVE_HOOKS[cls] = main_process_only(method) if save_on_main_only else method
-        if hasattr(method, "_speechbrain_loader"): DEFAULT_LOAD_HOOKS[cls] = method
-        if hasattr(method, "_speechbrain_transfer"): DEFAULT_TRANSFER_HOOKS[cls] = method
-
-    return cls
-
-def mark_as_saver(method):
-    sig = inspect.signature(method)
-
-    try:
-        sig.bind(object(), os.path.join("testpath"))
-    except TypeError:
-        raise TypeError
-
-    method._speechbrain_saver = True
-    return method
-
-def mark_as_transfer(method):
-    sig = inspect.signature(method)
-
-    try:
-        sig.bind(object(), os.path.join("testpath"))
-    except TypeError:
-        raise TypeError
-
-    method._speechbrain_transfer = True
-    return method
-
-def mark_as_loader(method):
-    sig = inspect.signature(method)
-
-    try:
-        sig.bind(object(), os.path.join("testpath"), True)
-    except TypeError:
-        raise TypeError
-
-    method._speechbrain_loader = True
-    return method
-
-def ddp_all_reduce(communication_object, reduce_op):
-    if MAIN_PROC_ONLY >= 1 or not is_distributed_initialized(): return communication_object
-    torch.distributed.all_reduce(communication_object, op=reduce_op)
-
-    return communication_object
-
-def fwd_default_precision(fwd = None, cast_inputs = torch.float32):
-    if fwd is None: return functools.partial(fwd_default_precision, cast_inputs=cast_inputs)
-
-    wrapped_fwd = torch.cuda.amp.custom_fwd(fwd, cast_inputs=cast_inputs)
-
-    @functools.wraps(fwd)
-    def wrapper(*args, force_allow_autocast = False, **kwargs):
-        return fwd(*args, **kwargs) if force_allow_autocast else wrapped_fwd(*args, **kwargs)
-
-    return wrapper
-
-def spectral_magnitude(stft, power = 1, log = False, eps = 1e-14):
-    spectr = stft.pow(2).sum(-1)
-
-    if power < 1: spectr = spectr + eps
-    spectr = spectr.pow(power)
-
-    if log: return torch.log(spectr + eps)
-    return spectr
-
-class Filterbank(torch.nn.Module):
-    def __init__(self, n_mels=40, log_mel=True, filter_shape="triangular", f_min=0, f_max=8000, n_fft=400, sample_rate=16000, power_spectrogram=2, amin=1e-10, ref_value=1.0, top_db=80.0, param_change_factor=1.0, param_rand_factor=0.0, freeze=True):
-        super().__init__()
-        self.n_mels = n_mels
-        self.log_mel = log_mel
-        self.filter_shape = filter_shape
-        self.f_min = f_min
-        self.f_max = f_max
-        self.n_fft = n_fft
-        self.sample_rate = sample_rate
-        self.power_spectrogram = power_spectrogram
-        self.amin = amin
-        self.ref_value = ref_value
-        self.top_db = top_db
-        self.freeze = freeze
-        self.n_stft = self.n_fft // 2 + 1
-        self.db_multiplier = math.log10(max(self.amin, self.ref_value))
-        self.device_inp = torch.device("cpu")
-        self.param_change_factor = param_change_factor
-        self.param_rand_factor = param_rand_factor
-        self.multiplier = 10 if self.power_spectrogram == 2 else 20
-
-        hz = self._to_hz(torch.linspace(self._to_mel(self.f_min), self._to_mel(self.f_max), self.n_mels + 2))
-
-        band = hz[1:] - hz[:-1]
-        self.band = band[:-1]
-        self.f_central = hz[1:-1]
-
-        if not self.freeze:
-            self.f_central = torch.nn.Parameter(self.f_central / (self.sample_rate * self.param_change_factor))
-            self.band = torch.nn.Parameter(self.band / (self.sample_rate * self.param_change_factor))
-
-        self.all_freqs_mat = torch.linspace(0, self.sample_rate // 2, self.n_stft).repeat(self.f_central.shape[0], 1)
-
-    def forward(self, spectrogram):
-        f_central_mat = self.f_central.repeat(self.all_freqs_mat.shape[1], 1).transpose(0, 1)
-        band_mat = self.band.repeat(self.all_freqs_mat.shape[1], 1).transpose(0, 1)
-
-        if not self.freeze:
-            f_central_mat = f_central_mat * (self.sample_rate * self.param_change_factor * self.param_change_factor)
-            band_mat = band_mat * (self.sample_rate * self.param_change_factor * self.param_change_factor)
-        elif self.param_rand_factor != 0 and self.training:
-            rand_change = (1.0 + torch.rand(2) * 2 * self.param_rand_factor - self.param_rand_factor)
-            f_central_mat = f_central_mat * rand_change[0]
-            band_mat = band_mat * rand_change[1]
-
-        fbank_matrix = self._create_fbank_matrix(f_central_mat, band_mat).to(spectrogram.device)
-        sp_shape = spectrogram.shape
-        if len(sp_shape) == 4: spectrogram = spectrogram.permute(0, 3, 1, 2).reshape(sp_shape[0] * sp_shape[3], sp_shape[1], sp_shape[2])
-
-        fbanks = torch.matmul(spectrogram, fbank_matrix)
-        if self.log_mel: fbanks = self._amplitude_to_DB(fbanks)
-
-        if len(sp_shape) == 4:
-            fb_shape = fbanks.shape
-            fbanks = fbanks.reshape(sp_shape[0], sp_shape[3], fb_shape[1], fb_shape[2]).permute(0, 2, 3, 1)
-
-        return fbanks
-
-    @staticmethod
-    def _to_mel(hz):
-        return 2595 * math.log10(1 + hz / 700)
-
-    @staticmethod
-    def _to_hz(mel):
-        return 700 * (10 ** (mel / 2595) - 1)
-
-    def _triangular_filters(self, all_freqs, f_central, band):
-        slope = (all_freqs - f_central) / band
-        return torch.max(torch.zeros(1, device=self.device_inp), torch.min(slope + 1.0, -slope + 1.0)).transpose(0, 1)
-
-    def _rectangular_filters(self, all_freqs, f_central, band):
-        left_side = right_size = all_freqs.ge(f_central - band)
-        right_size = all_freqs.le(f_central + band)
-
-        return (left_side * right_size).float().transpose(0, 1)
-
-    def _gaussian_filters(self, all_freqs, f_central, band, smooth_factor=torch.tensor(2)):
-        return torch.exp(-0.5 * ((all_freqs - f_central) / (band / smooth_factor)) ** 2).transpose(0, 1)
-
-    def _create_fbank_matrix(self, f_central_mat, band_mat):
-        if self.filter_shape == "triangular": fbank_matrix = self._triangular_filters(self.all_freqs_mat, f_central_mat, band_mat)
-        elif self.filter_shape == "rectangular": fbank_matrix = self._rectangular_filters(self.all_freqs_mat, f_central_mat, band_mat)
-        else: fbank_matrix = self._gaussian_filters(self.all_freqs_mat, f_central_mat, band_mat)
-
-        return fbank_matrix
-
-    def _amplitude_to_DB(self, x):
-        x_db = self.multiplier * torch.log10(torch.clamp(x, min=self.amin))
-        x_db -= self.multiplier * self.db_multiplier
-
-        return torch.max(x_db, (x_db.amax(dim=(-2, -1)) - self.top_db).view(x_db.shape[0], 1, 1))
-
-class ContextWindow(torch.nn.Module):
-    def __init__(self, left_frames=0, right_frames=0):
-        super().__init__()
-        self.left_frames = left_frames
-        self.right_frames = right_frames
-        self.context_len = self.left_frames + self.right_frames + 1
-        self.kernel_len = 2 * max(self.left_frames, self.right_frames) + 1
-        self.kernel = torch.eye(self.context_len, self.kernel_len)
-
-        if self.right_frames > self.left_frames: self.kernel = torch.roll(self.kernel, self.right_frames - self.left_frames, 1)
-        self.first_call = True
-
-    def forward(self, x):
-        x = x.transpose(1, 2)
-        if self.first_call:
-            self.first_call = False
-            self.kernel = (self.kernel.repeat(x.shape[1], 1, 1).view(x.shape[1] * self.context_len, self.kernel_len).unsqueeze(1))
-
-        or_shape = x.shape
-        if len(or_shape) == 4: x = x.reshape(or_shape[0] * or_shape[2], or_shape[1], or_shape[3])
-
-        cw_x = torch.nn.functional.conv1d(x, self.kernel.to(x.device), groups=x.shape[1], padding=max(self.left_frames, self.right_frames))
-        if len(or_shape) == 4: cw_x = cw_x.reshape(or_shape[0], cw_x.shape[1], or_shape[2], cw_x.shape[-1])
-
-        return cw_x.transpose(1, 2)
-
-class FilterProperties:
-    def __init__(self, window_size = 0, stride = 1, dilation = 1, causal = False):
-        self.window_size = window_size
-        self.stride = stride
-        self.dilation = dilation
-        self.causal = causal
-
-    def __post_init__(self):
-        assert self.window_size > 0
-        assert self.stride > 0
-        assert (self.dilation > 0)
-
-    @staticmethod
-    def pointwise_filter():
-        return FilterProperties(window_size=1, stride=1)
-
-    def get_effective_size(self):
-        return 1 + ((self.window_size - 1) * self.dilation)
-
-    def get_convolution_padding(self):
-        if self.window_size % 2 == 0: raise ValueError
-        if self.causal: return self.get_effective_size() - 1
-
-        return (self.get_effective_size() - 1) // 2
-
-    def get_noncausal_equivalent(self):
-        if not self.causal: return self
-        return FilterProperties(window_size=(self.window_size - 1) * 2 + 1, stride=self.stride, dilation=self.dilation, causal=False)
-
-    def with_on_top(self, other, allow_approximate=True):
-        self_size = self.window_size
-
-        if other.window_size % 2 == 0:
-            if allow_approximate: other_size = other.window_size + 1
-            else: raise ValueError
-        else: other_size = other.window_size
-
-        if (self.causal or other.causal) and not (self.causal and other.causal):
-            if allow_approximate: return self.get_noncausal_equivalent().with_on_top(other.get_noncausal_equivalent())
-            else: raise ValueError
-
-        return FilterProperties(self_size + (self.stride * (other_size - 1)), self.stride * other.stride, self.dilation * other.dilation, self.causal)
-
-class STFT(torch.nn.Module):
-    def __init__(self, sample_rate, win_length=25, hop_length=10, n_fft=400, window_fn=torch.hamming_window, normalized_stft=False, center=True, pad_mode="constant", onesided=True):
-        super().__init__()
-        self.sample_rate = sample_rate
-        self.win_length = win_length
-        self.hop_length = hop_length
-        self.n_fft = n_fft
-        self.normalized_stft = normalized_stft
-        self.center = center
-        self.pad_mode = pad_mode
-        self.onesided = onesided
-        self.win_length = int(round((self.sample_rate / 1000.0) * self.win_length))
-        self.hop_length = int(round((self.sample_rate / 1000.0) * self.hop_length))
-        self.window = window_fn(self.win_length)
-
-    def forward(self, x):
-        or_shape = x.shape
-        if len(or_shape) == 3: x = x.transpose(1, 2).reshape(or_shape[0] * or_shape[2], or_shape[1])
-
-        stft = torch.view_as_real(torch.stft(x, self.n_fft, self.hop_length, self.win_length, self.window.to(x.device), self.center, self.pad_mode, self.normalized_stft, self.onesided, return_complex=True))
-        stft = stft.reshape(or_shape[0], or_shape[2], stft.shape[1], stft.shape[2], stft.shape[3]).permute(0, 3, 2, 4, 1) if len(or_shape) == 3 else stft.transpose(2, 1)
-
-        return stft
-
-    def get_filter_properties(self):
-        if not self.center: raise ValueError
-        return FilterProperties(window_size=self.win_length, stride=self.hop_length)
-
-class Deltas(torch.nn.Module):
-    def __init__(self, input_size, window_length=5):
-        super().__init__()
-        self.n = (window_length - 1) // 2
-        self.denom = self.n * (self.n + 1) * (2 * self.n + 1) / 3
-        self.register_buffer("kernel", torch.arange(-self.n, self.n + 1, dtype=torch.float32).repeat(input_size, 1, 1),)
-
-    def forward(self, x):
-        x = x.transpose(1, 2).transpose(2, -1)
-        or_shape = x.shape
-
-        if len(or_shape) == 4: x = x.reshape(or_shape[0] * or_shape[2], or_shape[1], or_shape[3])
-
-        x = torch.nn.functional.pad(x, (self.n, self.n), mode="replicate")
-        delta_coeff = (torch.nn.functional.conv1d(x, self.kernel.to(x.device), groups=x.shape[1]) / self.denom)
-
-        if len(or_shape) == 4: delta_coeff = delta_coeff.reshape(or_shape[0], or_shape[1], or_shape[2], or_shape[3])
-        return delta_coeff.transpose(1, -1).transpose(2, -1)
-
-class Fbank(torch.nn.Module):
-    def __init__(self, deltas=False, context=False, requires_grad=False, sample_rate=16000, f_min=0, f_max=None, n_fft=400, n_mels=40, filter_shape="triangular", param_change_factor=1.0, param_rand_factor=0.0, left_frames=5, right_frames=5, win_length=25, hop_length=10):
-        super().__init__()
-        self.deltas = deltas
-        self.context = context
-        self.requires_grad = requires_grad
-        if f_max is None: f_max = sample_rate / 2
-        self.compute_STFT = STFT(sample_rate=sample_rate,n_fft=n_fft,win_length=win_length,hop_length=hop_length)
-        self.compute_fbanks = Filterbank(sample_rate=sample_rate,n_fft=n_fft,n_mels=n_mels,f_min=f_min,f_max=f_max,freeze=not requires_grad,filter_shape=filter_shape,param_change_factor=param_change_factor,param_rand_factor=param_rand_factor)
-        self.compute_deltas = Deltas(input_size=n_mels)
-        self.context_window = ContextWindow(left_frames=left_frames, right_frames=right_frames)
-
-    @fwd_default_precision(cast_inputs=torch.float32)
-    def forward(self, wav):
-        fbanks = self.compute_fbanks(spectral_magnitude(self.compute_STFT(wav)))
-        if self.deltas:
-            delta1 = self.compute_deltas(fbanks)
-            fbanks = torch.cat([fbanks, delta1, self.compute_deltas(delta1)], dim=2)
-
-        if self.context: fbanks = self.context_window(fbanks)
-        return fbanks
-
-    def get_filter_properties(self):
-        return self.compute_STFT.get_filter_properties()
-
-@register_checkpoint_hooks
-class InputNormalization(torch.nn.Module):
-    def __init__(self, mean_norm=True, std_norm=True, norm_type="global", avg_factor=None, requires_grad=False, update_until_epoch=3):
-        super().__init__()
-        self.mean_norm = mean_norm
-        self.std_norm = std_norm
-        self.norm_type = norm_type
-        self.avg_factor = avg_factor
-        self.requires_grad = requires_grad
-        self.glob_mean = torch.tensor([0])
-        self.glob_std = torch.tensor([0])
-        self.spk_dict_mean = {}
-        self.spk_dict_std = {}
-        self.spk_dict_count = {}
-        self.weight = 1.0
-        self.count = 0
-        self.eps = 1e-10
-        self.update_until_epoch = update_until_epoch
-
-    def forward(self, x, lengths, spk_ids = torch.tensor([]), epoch=0):
-        N_batches = x.shape[0]
-        current_means, current_stds = [], []
-
-        if self.norm_type == "sentence" or self.norm_type == "speaker": out = torch.empty_like(x)
-
-        for snt_id in range(N_batches):
-            actual_size = torch.round(lengths[snt_id] * x.shape[1]).int()
-            current_mean, current_std = self._compute_current_stats(x[snt_id, 0:actual_size, ...])
-
-            current_means.append(current_mean)
-            current_stds.append(current_std)
-
-            if self.norm_type == "sentence": out[snt_id] = (x[snt_id] - current_mean.data) / current_std.data
-
-            if self.norm_type == "speaker":
-                spk_id = int(spk_ids[snt_id][0])
-
-                if self.training:
-                    if spk_id not in self.spk_dict_mean:
-                        self.spk_dict_mean[spk_id] = current_mean
-                        self.spk_dict_std[spk_id] = current_std
-                        self.spk_dict_count[spk_id] = 1
-                    else:
-                        self.spk_dict_count[spk_id] = (self.spk_dict_count[spk_id] + 1)
-                        self.weight = (1 / self.spk_dict_count[spk_id]) if self.avg_factor is None else self.avg_factor
-
-                        self.spk_dict_mean[spk_id] = (1 - self.weight) * self.spk_dict_mean[spk_id].to(current_mean) + self.weight * current_mean
-                        self.spk_dict_std[spk_id] = (1 - self.weight) * self.spk_dict_std[spk_id].to(current_std) + self.weight * current_std
-
-                        self.spk_dict_mean[spk_id].detach()
-                        self.spk_dict_std[spk_id].detach()
-
-                    speaker_mean = self.spk_dict_mean[spk_id].data
-                    speaker_std = self.spk_dict_std[spk_id].data
-                else:
-                    if spk_id in self.spk_dict_mean:
-                        speaker_mean = self.spk_dict_mean[spk_id].data
-                        speaker_std = self.spk_dict_std[spk_id].data
-                    else:
-                        speaker_mean = current_mean.data
-                        speaker_std = current_std.data
-
-                out[snt_id] = (x[snt_id] - speaker_mean) / speaker_std
-
-        if self.norm_type == "batch" or self.norm_type == "global":
-            current_mean = ddp_all_reduce(torch.mean(torch.stack(current_means), dim=0), torch.distributed.ReduceOp.AVG)
-            current_std = ddp_all_reduce(torch.mean(torch.stack(current_stds), dim=0), torch.distributed.ReduceOp.AVG)
-
-            if self.norm_type == "batch": out = (x - current_mean.data) / (current_std.data)
-
-            if self.norm_type == "global":
-                if self.training:
-                    if self.count == 0:
-                        self.glob_mean = current_mean
-                        self.glob_std = current_std
-                    elif epoch is None or epoch < self.update_until_epoch:
-                        self.weight = (1 / (self.count + 1)) if self.avg_factor is None else self.avg_factor
-                        self.glob_mean = (1 - self.weight) * self.glob_mean.to(current_mean) + self.weight * current_mean
-                        self.glob_std = (1 - self.weight) * self.glob_std.to(current_std) + self.weight * current_std
-
-                    self.glob_mean.detach()
-                    self.glob_std.detach()
-                    self.count = self.count + 1
-
-                out = (x - self.glob_mean.data.to(x)) / (self.glob_std.data.to(x))
-
-        return out
-
-    def _compute_current_stats(self, x):
-        current_std = torch.std(x, dim=0).detach().data if self.std_norm else torch.tensor([1.0], device=x.device)
-        return torch.mean(x, dim=0).detach().data if self.mean_norm else torch.tensor([0.0], device=x.device), torch.max(current_std, self.eps * torch.ones_like(current_std))

-    def _statistics_dict(self):
-        state = {}
-        state["count"] = self.count
-        state["glob_mean"] = self.glob_mean
-        state["glob_std"] = self.glob_std
-        state["spk_dict_mean"] = self.spk_dict_mean
-        state["spk_dict_std"] = self.spk_dict_std
-        state["spk_dict_count"] = self.spk_dict_count
-
-        return state
-
-    def _load_statistics_dict(self, state):
-        self.count = state["count"]
-
-        if isinstance(state["glob_mean"], int):
-            self.glob_mean = state["glob_mean"]
-            self.glob_std = state["glob_std"]
-        else:
-            self.glob_mean = state["glob_mean"]
-            self.glob_std = state["glob_std"]
-
-        self.spk_dict_mean = {}
-        for spk in state["spk_dict_mean"]:
-            self.spk_dict_mean[spk] = state["spk_dict_mean"][spk]
-
-        self.spk_dict_std = {}
-        for spk in state["spk_dict_std"]:
-            self.spk_dict_std[spk] = state["spk_dict_std"][spk]
-
-        self.spk_dict_count = state["spk_dict_count"]
-        return state
-
-    def to(self, device):
-        self = super(InputNormalization, self).to(device)
-        self.glob_mean = self.glob_mean.to(device)
-        self.glob_std = self.glob_std.to(device)
-
-        for spk in self.spk_dict_mean:
-            self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(device)
-            self.spk_dict_std[spk] = self.spk_dict_std[spk].to(device)
-
-        return self
-
-    @mark_as_saver
-    def _save(self, path):
-        torch.save(self._statistics_dict(), path)
-
-    @mark_as_transfer
-    @mark_as_loader
-    def _load(self, path, end_of_epoch=False):
-        del end_of_epoch
-        stats = torch.load(path, map_location="cpu")
-        self._load_statistics_dict(stats)

main/library/speaker_diarization/parameter_transfer.py
DELETED
@@ -1,120 +0,0 @@
-import os
-import sys
-import inspect
-
-sys.path.append(os.getcwd())
-
-from main.library.speaker_diarization.speechbrain import fetch, run_on_main
-from main.library.speaker_diarization.features import DEFAULT_TRANSFER_HOOKS, DEFAULT_LOAD_HOOKS
-
-
-def get_default_hook(obj, default_hooks):
-    for cls in inspect.getmro(type(obj)):
-        if cls in default_hooks: return default_hooks[cls]
-
-    return None
-
-class Pretrainer:
-    def __init__(self, loadables=None, paths=None, custom_hooks=None, conditions=None):
-        self.loadables = {}
-
-        if loadables is not None: self.add_loadables(loadables)
-        self.paths = {}
-
-        if paths is not None: self.add_paths(paths)
-        self.custom_hooks = {}
-
-        if custom_hooks is not None: self.add_custom_hooks(custom_hooks)
-        self.conditions = {}
-
-        if conditions is not None: self.add_conditions(conditions)
-        self.is_local = []
-
-    def add_loadables(self, loadables):
-        self.loadables.update(loadables)
-
-    def add_paths(self, paths):
-        self.paths.update(paths)
-
-    def add_custom_hooks(self, custom_hooks):
-        self.custom_hooks.update(custom_hooks)
-
-    def add_conditions(self, conditions):
-        self.conditions.update(conditions)
-
-    @staticmethod
-    def split_path(path):
-        def split(src):
-            if "/" in src: return src.rsplit("/", maxsplit=1)
-            else: return "./", src
-
-        return split(path)
-
-    def collect_files(self, default_source=None):
-        loadable_paths = {}
-        for name in self.loadables:
-            if not self.is_loadable(name): continue
-            save_filename = name + ".ckpt"
-
-            if name in self.paths: source, filename = self.split_path(self.paths[name])
-            elif default_source is not None:
-                filename = save_filename
-                source = default_source
-            else: raise ValueError
-
-            fetch_kwargs = {"filename": filename, "source": source}
-            path = None
-
-            def run_fetch(**kwargs):
-                nonlocal path
-
-                path = fetch(**kwargs)
-
-            run_on_main(run_fetch, kwargs=fetch_kwargs, post_func=run_fetch, post_kwargs=fetch_kwargs)
-
-            loadable_paths[name] = path
-            self.paths[name] = str(path)
-            self.is_local.append(name)
-
-        return loadable_paths
-
-    def is_loadable(self, name):
-        if name not in self.conditions: return True
-        condition = self.conditions[name]
-
-        if callable(condition): return condition()
-        else: return bool(condition)
-
-    def load_collected(self):
-        paramfiles = {}
-        for name in self.loadables:
-            if not self.is_loadable(name): continue
-
-            if name in self.is_local: paramfiles[name] = self.paths[name]
-            else: raise ValueError
-
-        self._call_load_hooks(paramfiles)
-
-    def _call_load_hooks(self, paramfiles):
-        for name, obj in self.loadables.items():
-            if not self.is_loadable(name): continue
-            loadpath = paramfiles[name]
-
-            if name in self.custom_hooks:
-                self.custom_hooks[name](obj, loadpath)
-                continue
-
-            default_hook = get_default_hook(obj, DEFAULT_TRANSFER_HOOKS)
-
-            if default_hook is not None:
-                default_hook(obj, loadpath)
-                continue
-
-            default_hook = get_default_hook(obj, DEFAULT_LOAD_HOOKS)
-
-            if default_hook is not None:
-                end_of_epoch = False
-                default_hook(obj, loadpath, end_of_epoch)
-                continue
-
-            raise RuntimeError

main/library/speaker_diarization/segment.py
DELETED
@@ -1,540 +0,0 @@
-import numpy as np
-
-from sortedcontainers import SortedList
-
-PYANNOTE_SEGMENT = 'segment'
-
-
-class Timeline:
-    @classmethod
-    def from_df(cls, df, uri = None):
-        return cls(segments=list(df[PYANNOTE_SEGMENT]), uri=uri)
-
-    def __init__(self, segments = None, uri = None):
-        if segments is None: segments = ()
-        segments_set = set([segment for segment in segments if segment])
-
-        self.segments_set_ = segments_set
-        self.segments_list_ = SortedList(segments_set)
-        self.segments_boundaries_ = SortedList((boundary for segment in segments_set for boundary in segment))
-        self.uri = uri
-
-    def __len__(self):
-        return len(self.segments_set_)
-
-    def __nonzero__(self):
-        return self.__bool__()
-
-    def __bool__(self):
-        return len(self.segments_set_) > 0
-
-    def __iter__(self):
-        return iter(self.segments_list_)
-
-    def __getitem__(self, k):
-        return self.segments_list_[k]
-
-    def __eq__(self, other):
-        return self.segments_set_ == other.segments_set_
-
-    def __ne__(self, other):
-        return self.segments_set_ != other.segments_set_
-
-    def index(self, segment):
-        return self.segments_list_.index(segment)
-
-    def add(self, segment):
-        segments_set_ = self.segments_set_
-        if segment in segments_set_ or not segment: return self
-
-        segments_set_.add(segment)
-        self.segments_list_.add(segment)
-
-        segments_boundaries_ = self.segments_boundaries_
-        segments_boundaries_.add(segment.start)
-        segments_boundaries_.add(segment.end)
-
-        return self
-
-    def remove(self, segment):
-        segments_set_ = self.segments_set_
-        if segment not in segments_set_: return self
-
-        segments_set_.remove(segment)
-        self.segments_list_.remove(segment)
-
-        segments_boundaries_ = self.segments_boundaries_
-        segments_boundaries_.remove(segment.start)
-        segments_boundaries_.remove(segment.end)
-
-        return self
-
-    def discard(self, segment):
-        return self.remove(segment)
-
-    def __ior__(self, timeline):
-        return self.update(timeline)
-
-    def update(self, timeline):
-        segments_set = self.segments_set_
-        segments_set |= timeline.segments_set_
-
-        self.segments_list_ = SortedList(segments_set)
-        self.segments_boundaries_ = SortedList((boundary for segment in segments_set for boundary in segment))
-
-        return self
-
-    def __or__(self, timeline):
-        return self.union(timeline)
-
-    def union(self, timeline):
-        return Timeline(segments=self.segments_set_ | timeline.segments_set_, uri=self.uri)
-
-    def co_iter(self, other):
-        for segment in self.segments_list_:
-            temp = Segment(start=segment.end, end=segment.end)
-
-            for other_segment in other.segments_list_.irange(maximum=temp):
-                if segment.intersects(other_segment): yield segment, other_segment
-
-    def crop_iter(self, support, mode = 'intersection', returns_mapping = False):
-        if mode not in {'loose', 'strict', 'intersection'}: raise ValueError
-        if not isinstance(support, (Segment, Timeline)): raise TypeError
-
-        if isinstance(support, Segment):
-            support = Timeline(segments=([support] if support else []), uri=self.uri)
-
-            for yielded in self.crop_iter(support, mode=mode, returns_mapping=returns_mapping):
-                yield yielded
-
-            return
-
-        support = support.support()
-
-        if mode == 'loose':
-            for segment, _ in self.co_iter(support):
-                yield segment
-
-            return
-
-        if mode == 'strict':
-            for segment, other_segment in self.co_iter(support):
-                if segment in other_segment: yield segment
-
-            return
-
-        for segment, other_segment in self.co_iter(support):
-            mapped_to = segment & other_segment
-            if not mapped_to: continue
-
-            if returns_mapping: yield segment, mapped_to
-            else: yield mapped_to
-
-    def crop(self, support, mode = 'intersection', returns_mapping = False):
-        if mode == 'intersection' and returns_mapping:
-            segments, mapping = [], {}
-
-            for segment, mapped_to in self.crop_iter(support, mode='intersection', returns_mapping=True):
-                segments.append(mapped_to)
-                mapping[mapped_to] = mapping.get(mapped_to, list()) + [segment]
-
-            return Timeline(segments=segments, uri=self.uri), mapping
-
-        return Timeline(segments=self.crop_iter(support, mode=mode), uri=self.uri)
-
-    def overlapping(self, t):
-        return list(self.overlapping_iter(t))
-
-    def overlapping_iter(self, t):
-        for segment in self.segments_list_.irange(maximum=Segment(start=t, end=t)):
-            if segment.overlaps(t): yield segment
-
-    def get_overlap(self):
-        overlaps_tl = Timeline(uri=self.uri)
-
-        for s1, s2 in self.co_iter(self):
-            if s1 == s2: continue
-
-            overlaps_tl.add(s1 & s2)
-
-        return overlaps_tl.support()
-
-    def extrude(self, removed, mode = 'intersection'):
-        if isinstance(removed, Segment): removed = Timeline([removed])
-
-        if mode == "loose": mode = "strict"
-        elif mode == "strict": mode = "loose"
-
-        return self.crop(removed.gaps(support=Timeline([self.extent()], uri=self.uri)), mode=mode)
-
-    def __str__(self):
-        n = len(self.segments_list_)
-        string = "["
-
-        for i, segment in enumerate(self.segments_list_):
-            string += str(segment)
-            string += "\n " if i + 1 < n else ""
-
-        string += "]"
-        return string
-
-    def __repr__(self):
-        return "<Timeline(uri=%s, segments=%s)>" % (self.uri, list(self.segments_list_))
-
-    def __contains__(self, included):
-        if isinstance(included, Segment): return included in self.segments_set_
-        elif isinstance(included, Timeline): return self.segments_set_.issuperset(included.segments_set_)
-        else: raise TypeError
-
-    def empty(self):
-        return Timeline(uri=self.uri)
-
-    def covers(self, other):
-        gaps = self.gaps(support=other.extent())
-
-        for _ in gaps.co_iter(other):
-            return False
-
-        return True
-
-    def copy(self, segment_func = None):
-        if segment_func is None: return Timeline(segments=self.segments_list_, uri=self.uri)
-        return Timeline(segments=[segment_func(s) for s in self.segments_list_], uri=self.uri)
-
-    def extent(self):
-        if self.segments_set_:
-            segments_boundaries_ = self.segments_boundaries_
-            return Segment(start=segments_boundaries_[0], end=segments_boundaries_[-1])
-
-        return Segment(start=0.0, end=0.0)
-
-    def support_iter(self, collar = 0.0):
-        if not self: return
-
-        new_segment = self.segments_list_[0]
-
-        for segment in self:
-            possible_gap = segment ^ new_segment
-
-            if not possible_gap or possible_gap.duration < collar: new_segment |= segment
-            else:
-                yield new_segment
-                new_segment = segment
-
-        yield new_segment
-
-    def support(self, collar = 0.):
-        return Timeline(segments=self.support_iter(collar), uri=self.uri)
-
-    def duration(self):
-        return sum(s.duration for s in self.support_iter())
-
-    def gaps_iter(self, support = None):
-        if support is None: support = self.extent()
-        if not isinstance(support, (Segment, Timeline)): raise TypeError
-
-        if isinstance(support, Segment):
-            end = support.start
-
-            for segment in self.crop(support, mode='intersection').support():
-                gap = Segment(start=end, end=segment.start)
-                if gap: yield gap
-
-                end = segment.end
-
-            gap = Segment(start=end, end=support.end)
-            if gap: yield gap
-        elif isinstance(support, Timeline):
-            for segment in support.support():
-                for gap in self.gaps_iter(support=segment):
-                    yield gap
-
-    def gaps(self, support = None):
-        return Timeline(segments=self.gaps_iter(support=support), uri=self.uri)
-
-    def segmentation(self):
-        support = self.support()
-        timestamps = set([])
-
-        for (start, end) in self:
-            timestamps.add(start)
-            timestamps.add(end)
-
-        timestamps = sorted(timestamps)
-        if len(timestamps) == 0: return Timeline(uri=self.uri)
-
-        segments = []
-        start = timestamps[0]
-
-        for end in timestamps[1:]:
-            segment = Segment(start=start, end=end)
-
-            if segment and support.overlapping(segment.middle): segments.append(segment)
-            start = end
-
-        return Timeline(segments=segments, uri=self.uri)
-
-    def _iter_uem(self):
-        uri = self.uri if self.uri else "<NA>"
-
-        for segment in self:
-            yield f"{uri} 1 {segment.start:.3f} {segment.end:.3f}\n"
-
-    def to_uem(self):
-        return "".join([line for line in self._iter_uem()])
-
-    def write_uem(self, file):
-        for line in self._iter_uem():
-            file.write(line)
-
-    def _repr_png_(self):
-        return None
-
-class Segment:
-    def __init__(self, start, end):
-        self.start = start
-        self.end = end
-
-    @staticmethod
-    def set_precision(ndigits = None):
-        global AUTO_ROUND_TIME, SEGMENT_PRECISION
-
-        if ndigits is None:
-            AUTO_ROUND_TIME = False
-            SEGMENT_PRECISION = 1e-6
-        else:
-            AUTO_ROUND_TIME = True
-            SEGMENT_PRECISION = 10 ** (-ndigits)
-
-    def __bool__(self):
-        return bool((self.end - self.start) > SEGMENT_PRECISION)
-
-    def __post_init__(self):
-        if AUTO_ROUND_TIME:
-            object.__setattr__(self, 'start', int(self.start / SEGMENT_PRECISION + 0.5) * SEGMENT_PRECISION)
-            object.__setattr__(self, 'end', int(self.end / SEGMENT_PRECISION + 0.5) * SEGMENT_PRECISION)
-
-    @property
-    def duration(self):
-        return self.end - self.start if self else 0.
-
-    @property
-    def middle(self):
-        return .5 * (self.start + self.end)
-
-    def __iter__(self):
-        yield self.start
-        yield self.end
-
-    def copy(self):
-        return Segment(start=self.start, end=self.end)
-
-    def __contains__(self, other):
-        return (self.start <= other.start) and (self.end >= other.end)
-
-    def __and__(self, other):
-        return Segment(start=max(self.start, other.start), end=min(self.end, other.end))
-
-    def intersects(self, other):
-        return (self.start < other.start and other.start < self.end - SEGMENT_PRECISION) or (self.start > other.start and self.start < other.end - SEGMENT_PRECISION) or (self.start == other.start)
-
-    def overlaps(self, t):
-        return self.start <= t and self.end >= t
-
-    def __or__(self, other):
-        if not self: return other
-        if not other: return self
-
-        return Segment(start=min(self.start, other.start), end=max(self.end, other.end))
-
-    def __xor__(self, other):
-        if (not self) or (not other): raise ValueError
-
-        return Segment(start=min(self.end, other.end), end=max(self.start, other.start))
-
-    def _str_helper(self, seconds):
-        from datetime import timedelta
-
-        negative = seconds < 0
-        td = timedelta(seconds=abs(seconds))
-
-        hours, remainder = divmod(td.seconds + 86400 * td.days, 3600)
-        minutes, seconds = divmod(remainder, 60)
-
-        return '%s%02d:%02d:%02d.%03d' % ('-' if negative else ' ', hours, minutes, seconds, td.microseconds / 1000)
-
-    def __str__(self):
-        if self: return '[%s --> %s]' % (self._str_helper(self.start), self._str_helper(self.end))
-        return '[]'
-
-    def __repr__(self):
-        return '<Segment(%g, %g)>' % (self.start, self.end)
-
-    def _repr_png_(self):
-        return None
-
-class SlidingWindow:
-    def __init__(self, duration=0.030, step=0.010, start=0.000, end=None):
-        if duration <= 0: raise ValueError
-        self.__duration = duration
-        if step <= 0: raise ValueError
-
-        self.__step = step
-        self.__start = start
-
-        if end is None: self.__end = np.inf
-        else:
-            if end <= start: raise ValueError
-            self.__end = end
-
-        self.__i = -1
-
-    @property
-    def start(self):
-        return self.__start
-
-    @property
-    def end(self):
-        return self.__end
-
-    @property
-    def step(self):
-        return self.__step
-
-    @property
-    def duration(self):
-        return self.__duration
-
-    def closest_frame(self, t):
-        return int(np.rint((t - self.__start - .5 * self.__duration) / self.__step))
-
-    def samples(self, from_duration, mode = 'strict'):
-        if mode == 'strict': return int(np.floor((from_duration - self.duration) / self.step)) + 1
-        elif mode == 'loose': return int(np.floor((from_duration + self.duration) / self.step))
-        elif mode == 'center': return int(np.rint((from_duration / self.step)))
-
-    def crop(self, focus, mode = 'loose', fixed = None, return_ranges = False):
-        if not isinstance(focus, (Segment, Timeline)): raise TypeError
-
-        if isinstance(focus, Timeline):
-            if fixed is not None: raise ValueError
-
-            if return_ranges:
-                ranges = []
-
-                for i, s in enumerate(focus.support()):
-                    rng = self.crop(s, mode=mode, fixed=fixed, return_ranges=True)
-
-                    if i == 0 or rng[0][0] > ranges[-1][1]: ranges += rng
-                    else: ranges[-1][1] = rng[0][1]
-
-                return ranges
-
-            return np.unique(np.hstack([self.crop(s, mode=mode, fixed=fixed, return_ranges=False) for s in focus.support()]))
-
-        if mode == 'loose':
-            i = int(np.ceil((focus.start - self.duration - self.start) / self.step))
-
-            if fixed is None:
-                j = int(np.floor((focus.end - self.start) / self.step))
-                rng = (i, j + 1)
-            else:
-                n = self.samples(fixed, mode='loose')
-                rng = (i, i + n)
-        elif mode == 'strict':
-            i = int(np.ceil((focus.start - self.start) / self.step))
-
-            if fixed is None:
-                j = int(np.floor((focus.end - self.duration - self.start) / self.step))
-                rng = (i, j + 1)
-            else:
-                n = self.samples(fixed, mode='strict')
-                rng = (i, i + n)
-        elif mode == 'center':
-            i = self.closest_frame(focus.start)
-
-            if fixed is None:
-                j = self.closest_frame(focus.end)
-                rng = (i, j + 1)
-            else:
-                n = self.samples(fixed, mode='center')
-                rng = (i, i + n)
-        else: raise ValueError
-
-        if return_ranges: return [list(rng)]
-        return np.array(range(*rng), dtype=np.int64)
-
-    def segmentToRange(self, segment):
-        return self.segment_to_range(segment)
-
-    def segment_to_range(self, segment):
-        return self.closest_frame(segment.start), int(segment.duration / self.step) + 1
-
-    def rangeToSegment(self, i0, n):
-        return self.range_to_segment(i0, n)
-
-    def range_to_segment(self, i0, n):
-        start = self.__start + (i0 - .5) * self.__step + .5 * self.__duration
-
-        if i0 == 0: start = self.start
-        return Segment(start, start + (n * self.__step))
-
-    def samplesToDuration(self, nSamples):
-        return self.samples_to_duration(nSamples)
-
-    def samples_to_duration(self, n_samples):
-        return self.range_to_segment(0, n_samples).duration
-
-    def durationToSamples(self, duration):
-        return self.duration_to_samples(duration)
-
-    def duration_to_samples(self, duration):
-        return self.segment_to_range(Segment(0, duration))[1]
-
-    def __getitem__(self, i):
-        start = self.__start + i * self.__step
-        if start >= self.__end: return None
-
-        return Segment(start=start, end=start + self.__duration)
-
-    def next(self):
-        return self.__next__()
-
-    def __next__(self):
-        self.__i += 1
-        window = self[self.__i]
-
-        if window: return window
-        else: raise StopIteration()
-
-    def __iter__(self):
-        self.__i = -1
-        return self
-
-    def __len__(self):
-        if np.isinf(self.__end): raise ValueError
-        i = self.closest_frame(self.__end)
-
-        while (self[i]):
-            i += 1
-
-        length = i
-        return length
-
-    def copy(self):
-        return self.__class__(duration=self.duration, step=self.step, start=self.start, end=self.end)
-
-    def __call__(self, support, align_last = False):
-        if isinstance(support, Timeline): segments = support
-        elif isinstance(support, Segment): segments = Timeline(segments=[support])
-        else: raise TypeError
-
-        for segment in segments:
-            if segment.duration < self.duration: continue
-
-            for s in SlidingWindow(duration=self.duration, step=self.step, start=segment.start, end=segment.end):
-                if s in segment:
-                    yield s
-                    last = s
-
-            if align_last and last.end < segment.end: yield Segment(start=segment.end - self.duration, end=segment.end)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
main/library/speaker_diarization/speechbrain.py
DELETED
@@ -1,220 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import torch
|
3 |
-
import torchaudio
|
4 |
-
|
5 |
-
from functools import wraps
|
6 |
-
from types import SimpleNamespace
|
7 |
-
from torch.nn import SyncBatchNorm
|
8 |
-
from hyperpyyaml import load_hyperpyyaml
|
9 |
-
|
10 |
-
from torch.nn import DataParallel as DP
|
11 |
-
from torch.nn.parallel import DistributedDataParallel as DDP
|
12 |
-
|
13 |
-
MAIN_PROC_ONLY = 0
|
14 |
-
|
15 |
-
def fetch(filename, source):
|
16 |
-
return os.path.abspath(os.path.join(source, filename))
|
17 |
-
|
18 |
-
def run_on_main(func, args=None, kwargs=None, post_func=None, post_args=None, post_kwargs=None, run_post_on_main=False):
|
19 |
-
if args is None: args = []
|
20 |
-
if kwargs is None: kwargs = {}
|
21 |
-
if post_args is None: post_args = []
|
22 |
-
if post_kwargs is None: post_kwargs = {}
|
23 |
-
|
24 |
-
main_process_only(func)(*args, **kwargs)
|
25 |
-
ddp_barrier()
|
26 |
-
|
27 |
-
if post_func is not None:
|
28 |
-
if run_post_on_main: post_func(*post_args, **post_kwargs)
|
29 |
-
else:
|
30 |
-
if not if_main_process(): post_func(*post_args, **post_kwargs)
|
31 |
-
ddp_barrier()
|
32 |
-
|
33 |
-
def is_distributed_initialized():
|
34 |
-
return (torch.distributed.is_available() and torch.distributed.is_initialized())
|
35 |
-
|
36 |
-
def if_main_process():
|
37 |
-
if is_distributed_initialized(): return torch.distributed.get_rank() == 0
|
38 |
-
else: return True
|
39 |
-
|
40 |
-
class MainProcessContext:
|
41 |
-
def __enter__(self):
|
42 |
-
global MAIN_PROC_ONLY
|
43 |
-
|
44 |
-
MAIN_PROC_ONLY += 1
|
45 |
-
return self
|
46 |
-
|
47 |
-
def __exit__(self, exc_type, exc_value, traceback):
|
48 |
-
global MAIN_PROC_ONLY
|
49 |
-
|
50 |
-
MAIN_PROC_ONLY -= 1
|
51 |
-
|
52 |
-
def main_process_only(function):
|
53 |
-
@wraps(function)
|
54 |
-
def main_proc_wrapped_func(*args, **kwargs):
|
55 |
-
with MainProcessContext():
|
56 |
-
return function(*args, **kwargs) if if_main_process() else None
|
57 |
-
|
58 |
-
return main_proc_wrapped_func
|
59 |
-
|
60 |
-
def ddp_barrier():
|
61 |
-
if MAIN_PROC_ONLY >= 1 or not is_distributed_initialized(): return
|
62 |
-
|
63 |
-
if torch.distributed.get_backend() == torch.distributed.Backend.NCCL: torch.distributed.barrier(device_ids=[torch.cuda.current_device()])
|
64 |
-
else: torch.distributed.barrier()
|
65 |
-
|
66 |
-
class Resample(torch.nn.Module):
|
67 |
-
def __init__(self, orig_freq=16000, new_freq=16000, *args, **kwargs):
|
68 |
-
super().__init__()
|
69 |
-
|
70 |
-
self.orig_freq = orig_freq
|
71 |
-
self.new_freq = new_freq
|
72 |
-
self.resampler = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=new_freq, *args, **kwargs)
|
73 |
-
|
74 |
-
def forward(self, waveforms):
|
75 |
-
if self.orig_freq == self.new_freq: return waveforms
|
76 |
-
|
77 |
-
unsqueezed = False
|
78 |
-
if len(waveforms.shape) == 2:
|
79 |
-
waveforms = waveforms.unsqueeze(1)
|
80 |
-
unsqueezed = True
|
81 |
-
elif len(waveforms.shape) == 3: waveforms = waveforms.transpose(1, 2)
|
82 |
-
else: raise ValueError
|
83 |
-
|
84 |
-
self.resampler.to(waveforms.device)
|
85 |
-
resampled_waveform = self.resampler(waveforms)
|
86 |
-
|
87 |
-
return resampled_waveform.squeeze(1) if unsqueezed else resampled_waveform.transpose(1, 2)
|
88 |
-
|
89 |
-
class AudioNormalizer:
|
90 |
-
def __init__(self, sample_rate=16000, mix="avg-to-mono"):
|
91 |
-
self.sample_rate = sample_rate
|
92 |
-
|
93 |
-
if mix not in ["avg-to-mono", "keep"]: raise ValueError
|
94 |
-
|
95 |
-
self.mix = mix
|
96 |
-
self._cached_resamplers = {}
|
97 |
-
|
98 |
-
def __call__(self, audio, sample_rate):
|
99 |
-
if sample_rate not in self._cached_resamplers: self._cached_resamplers[sample_rate] = Resample(sample_rate, self.sample_rate)
|
100 |
-
return self._mix(self._cached_resamplers[sample_rate](audio.unsqueeze(0)).squeeze(0))
|
101 |
-
|
102 |
-
def _mix(self, audio):
|
103 |
-
flat_input = audio.dim() == 1
|
104 |
-
|
105 |
-
if self.mix == "avg-to-mono":
|
106 |
-
if flat_input: return audio
|
107 |
-
return torch.mean(audio, 1)
|
108 |
-
|
109 |
-
if self.mix == "keep": return audio
|
110 |
-
|
111 |
-
class Pretrained(torch.nn.Module):
|
112 |
-
HPARAMS_NEEDED, MODULES_NEEDED = [], []
|
113 |
-
def __init__(self, modules=None, hparams=None, run_opts=None, freeze_params=True):
|
114 |
-
super().__init__()
|
115 |
-
|
116 |
-
for arg, default in {"device": "cpu", "data_parallel_count": -1, "data_parallel_backend": False, "distributed_launch": False, "distributed_backend": "nccl", "jit": False, "jit_module_keys": None, "compile": False, "compile_module_keys": None, "compile_mode": "reduce-overhead", "compile_using_fullgraph": False, "compile_using_dynamic_shape_tracing": False}.items():
|
117 |
-
if run_opts is not None and arg in run_opts: setattr(self, arg, run_opts[arg])
|
118 |
-
elif hparams is not None and arg in hparams: setattr(self, arg, hparams[arg])
|
119 |
-
else: setattr(self, arg, default)
|
120 |
-
|
121 |
-
self.mods = torch.nn.ModuleDict(modules)
|
122 |
-
|
123 |
-
for module in self.mods.values():
|
124 |
-
if module is not None: module.to(self.device)
|
125 |
-
|
126 |
-
if self.HPARAMS_NEEDED and hparams is None: raise ValueError
|
127 |
-
|
128 |
-
if hparams is not None:
|
129 |
-
for hp in self.HPARAMS_NEEDED:
|
130 |
-
if hp not in hparams: raise ValueError
|
131 |
-
|
132 |
-
self.hparams = SimpleNamespace(**hparams)
|
133 |
-
|
134 |
-
self._prepare_modules(freeze_params)
|
135 |
-
self.audio_normalizer = hparams.get("audio_normalizer", AudioNormalizer())
|
136 |
-
|
137 |
-
def _prepare_modules(self, freeze_params):
|
138 |
-
self._compile()
|
139 |
-
self._wrap_distributed()
|
140 |
-
|
141 |
-
if freeze_params:
|
142 |
-
self.mods.eval()
|
143 |
-
for p in self.mods.parameters():
|
144 |
-
p.requires_grad = False
|
145 |
-
|
146 |
-
def _compile(self):
|
147 |
-
compile_available = hasattr(torch, "compile")
|
148 |
-
if not compile_available and self.compile_module_keys is not None: raise ValueError
|
149 |
-
|
150 |
-
compile_module_keys = set()
|
151 |
-
if self.compile: compile_module_keys = set(self.mods) if self.compile_module_keys is None else set(self.compile_module_keys)
|
152 |
-
|
153 |
-
jit_module_keys = set()
|
154 |
-
if self.jit: jit_module_keys = set(self.mods) if self.jit_module_keys is None else set(self.jit_module_keys)
|
155 |
-
|
156 |
-
for name in compile_module_keys | jit_module_keys:
|
157 |
-
if name not in self.mods: raise ValueError
|
158 |
-
|
159 |
-
for name in compile_module_keys:
|
160 |
-
try:
|
161 |
-
module = torch.compile(self.mods[name], mode=self.compile_mode, fullgraph=self.compile_using_fullgraph, dynamic=self.compile_using_dynamic_shape_tracing)
|
162 |
-
except Exception:
|
163 |
-
continue
|
164 |
-
|
165 |
-
self.mods[name] = module.to(self.device)
|
166 |
-
jit_module_keys.discard(name)
|
167 |
-
|
168 |
-
for name in jit_module_keys:
|
169 |
-
module = torch.jit.script(self.mods[name])
|
170 |
-
self.mods[name] = module.to(self.device)
|
171 |
-
|
172 |
-
def _compile_jit(self):
|
173 |
-
self._compile()
|
174 |
-
|
175 |
-
def _wrap_distributed(self):
|
176 |
-
if not self.distributed_launch and not self.data_parallel_backend: return
|
177 |
-
elif self.distributed_launch:
|
178 |
-
for name, module in self.mods.items():
|
179 |
-
if any(p.requires_grad for p in module.parameters()): self.mods[name] = DDP(SyncBatchNorm.convert_sync_batchnorm(module), device_ids=[self.device])
|
180 |
-
else:
|
181 |
-
for name, module in self.mods.items():
|
182 |
-
if any(p.requires_grad for p in module.parameters()): self.mods[name] = DP(module) if self.data_parallel_count == -1 else DP(module, [i for i in range(self.data_parallel_count)])
|
183 |
-
|
184 |
-
@classmethod
|
185 |
-
def from_hparams(cls, source, hparams_file="hyperparams.yaml", overrides={}, download_only=False, overrides_must_match=True, **kwargs):
|
186 |
-
with open(fetch(filename=hparams_file, source=source)) as fin:
|
187 |
-
hparams = load_hyperpyyaml(fin, overrides, overrides_must_match=overrides_must_match)
|
188 |
-
|
189 |
-
pretrainer = hparams.get("pretrainer", None)
|
190 |
-
|
191 |
-
if pretrainer is not None:
|
192 |
-
run_on_main(pretrainer.collect_files, kwargs={"default_source": source})
|
193 |
-
if not download_only:
|
194 |
-
pretrainer.load_collected()
|
195 |
-
return cls(hparams["modules"], hparams, **kwargs)
|
196 |
-
else: return cls(hparams["modules"], hparams, **kwargs)
|
197 |
-
|
198 |
-
class EncoderClassifier(Pretrained):
|
199 |
-
MODULES_NEEDED = ["compute_features", "mean_var_norm", "embedding_model", "classifier"]
|
200 |
-
|
201 |
-
def encode_batch(self, wavs, wav_lens=None, normalize=False):
|
202 |
-
if len(wavs.shape) == 1: wavs = wavs.unsqueeze(0)
|
203 |
-
if wav_lens is None: wav_lens = torch.ones(wavs.shape[0], device=self.device)
|
204 |
-
|
205 |
-
wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
|
206 |
-
wavs = wavs.float()
|
207 |
-
|
208 |
-
embeddings = self.mods.embedding_model(self.mods.mean_var_norm(self.mods.compute_features(wavs), wav_lens), wav_lens)
|
209 |
-
|
210 |
-
if normalize: embeddings = self.hparams.mean_var_norm_emb(embeddings, torch.ones(embeddings.shape[0], device=self.device))
|
211 |
-
return embeddings
|
212 |
-
|
213 |
-
def classify_batch(self, wavs, wav_lens=None):
|
214 |
-
out_prob = self.mods.classifier(self.encode_batch(wavs, wav_lens)).squeeze(1)
|
215 |
-
score, index = torch.max(out_prob, dim=-1)
|
216 |
-
|
217 |
-
return out_prob, score, index, self.hparams.label_encoder.decode_torch(index)
|
218 |
-
|
219 |
-
def forward(self, wavs, wav_lens=None):
|
220 |
-
return self.classify_batch(wavs, wav_lens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
main/library/speaker_diarization/whisper.py
DELETED
@@ -1,1290 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import sys
|
3 |
-
import gzip
|
4 |
-
import zlib
|
5 |
-
import tqdm
|
6 |
-
import torch
|
7 |
-
import base64
|
8 |
-
import string
|
9 |
-
import logging
|
10 |
-
import tiktoken
|
11 |
-
import itertools
|
12 |
-
|
13 |
-
import numba as nb
|
14 |
-
import numpy as np
|
15 |
-
import torch.nn as nn
|
16 |
-
import torch.nn.functional as F
|
17 |
-
|
18 |
-
from contextlib import contextmanager
|
19 |
-
from torch.distributions import Categorical
|
20 |
-
from functools import cached_property, lru_cache
|
21 |
-
from dataclasses import dataclass, replace
|
22 |
-
from torch.nn.functional import scaled_dot_product_attention
|
23 |
-
|
24 |
-
sys.path.append(os.getcwd())
|
25 |
-
|
26 |
-
from main.library.utils import load_audio
|
27 |
-
|
28 |
-
LANGUAGES = {"en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian", "ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish", "pl": "polish", "ca": "catalan", "nl": "dutch", "ar": "arabic", "sv": "swedish", "it": "italian", "id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese", "he": "hebrew", "uk": "ukrainian", "el": "greek", "ms": "malay", "cs": "czech", "ro": "romanian", "da": "danish", "hu": "hungarian", "ta": "tamil", "no": "norwegian", "th": "thai", "ur": "urdu", "hr": "croatian", "bg": "bulgarian", "lt": "lithuanian", "la": "latin", "mi": "maori", "ml": "malayalam", "cy": "welsh", "sk": "slovak", "te": "telugu", "fa": "persian", "lv": "latvian", "bn": "bengali", "sr": "serbian", "az": "azerbaijani", "sl": "slovenian", "kn": "kannada", "et": "estonian", "mk": "macedonian", "br": "breton", "eu": "basque", "is": "icelandic", "hy": "armenian", "ne": "nepali", "mn": "mongolian", "bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili", "gl": "galician", "mr": "marathi", "pa": "punjabi", "si": "sinhala", "km": "khmer", "sn": "shona", "yo": "yoruba", "so": "somali", "af": "afrikaans", "oc": "occitan", "ka": "georgian", "be": "belarusian", "tg": "tajik", "sd": "sindhi", "gu": "gujarati", "am": "amharic", "yi": "yiddish", "lo": "lao", "uz": "uzbek", "fo": "faroese", "ht": "haitian creole", "ps": "pashto", "tk": "turkmen", "nn": "nynorsk", "mt": "maltese", "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar", "bo": "tibetan", "tl": "tagalog", "mg": "malagasy", "as": "assamese", "tt": "tatar", "haw": "hawaiian", "ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese", "yue": "cantonese"}
|
29 |
-
TO_LANGUAGE_CODE = {**{language: code for code, language in LANGUAGES.items()}, "burmese": "my", "valencian": "ca", "flemish": "nl", "haitian": "ht", "letzeburgesch": "lb", "pushto": "ps", "panjabi": "pa", "moldavian": "ro", "moldovan": "ro", "sinhalese": "si", "castilian": "es", "mandarin": "zh"}
|
30 |
-
_ALIGNMENT_HEADS = {"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO", "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m", "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000", "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9", "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj", "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`"}
|
31 |
-
|
32 |
-
SAMPLE_RATE, N_FFT, HOP_LENGTH, CHUNK_LENGTH = 16000, 400, 160, 30
|
33 |
-
|
34 |
-
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE
|
35 |
-
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2
|
36 |
-
|
37 |
-
def exact_div(x, y):
|
38 |
-
assert x % y == 0
|
39 |
-
return x // y
|
40 |
-
|
41 |
-
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)
|
42 |
-
|
43 |
-
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)
|
44 |
-
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)
|
45 |
-
|
46 |
-
|
47 |
-
def load_model(name = "base", device = "cpu"):
|
48 |
-
checkpoint_file = os.path.join("assets", "models", "speaker_diarization", "models", name + ".pt")
|
49 |
-
alignment_heads = _ALIGNMENT_HEADS[name]
|
50 |
-
|
51 |
-
with open(checkpoint_file, "rb") as fp:
|
52 |
-
checkpoint = torch.load(fp, map_location=device)
|
53 |
-
|
54 |
-
del checkpoint_file
|
55 |
-
|
56 |
-
model = Whisper(ModelDimensions(**checkpoint["dims"]))
|
57 |
-
model.load_state_dict(checkpoint["model_state_dict"])
|
58 |
-
model.set_alignment_heads(alignment_heads)
|
59 |
-
|
60 |
-
return model.to(device)
|
61 |
-
|
62 |
-
def merge_punctuations(alignment, prepended, appended):
|
63 |
-
i = len(alignment) - 2
|
64 |
-
j = len(alignment) - 1
|
65 |
-
|
66 |
-
while i >= 0:
|
67 |
-
previous = alignment[i]
|
68 |
-
following = alignment[j]
|
69 |
-
|
70 |
-
if previous.word.startswith(" ") and previous.word.strip() in prepended:
|
71 |
-
following.word = previous.word + following.word
|
72 |
-
following.tokens = previous.tokens + following.tokens
|
73 |
-
|
74 |
-
previous.word = ""
|
75 |
-
previous.tokens = []
|
76 |
-
else: j = i
|
77 |
-
|
78 |
-
i -= 1
|
79 |
-
|
80 |
-
i = 0
|
81 |
-
j = 1
|
82 |
-
|
83 |
-
while j < len(alignment):
|
84 |
-
previous = alignment[i]
|
85 |
-
following = alignment[j]
|
86 |
-
|
87 |
-
if not previous.word.endswith(" ") and following.word in appended:
|
88 |
-
previous.word = previous.word + following.word
|
89 |
-
previous.tokens = previous.tokens + following.tokens
|
90 |
-
|
91 |
-
following.word = ""
|
92 |
-
following.tokens = []
|
93 |
-
else: i = j
|
94 |
-
|
95 |
-
j += 1
|
96 |
-
|
97 |
-
class WordTiming:
|
98 |
-
def __init__(self, word, tokens, start, end, probability):
|
99 |
-
self.word = word
|
100 |
-
self.tokens = tokens
|
101 |
-
self.start = start
|
102 |
-
self.end = end
|
103 |
-
self.probability = probability
|
104 |
-
|
105 |
-
@contextmanager
|
106 |
-
def disable_sdpa():
|
107 |
-
prev_state = MultiHeadAttention.use_sdpa
|
108 |
-
try:
|
109 |
-
MultiHeadAttention.use_sdpa = False
|
110 |
-
yield
|
111 |
-
finally:
|
112 |
-
MultiHeadAttention.use_sdpa = prev_state
|
113 |
-
|
114 |
-
def median_filter(x, filter_width):
|
115 |
-
pad_width = filter_width // 2
|
116 |
-
|
117 |
-
if x.shape[-1] <= pad_width: return x
|
118 |
-
if (ndim := x.ndim) <= 2: x = x[None, None, :]
|
119 |
-
|
120 |
-
assert (filter_width > 0 and filter_width % 2 == 1)
|
121 |
-
|
122 |
-
result = None
|
123 |
-
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
|
124 |
-
|
125 |
-
if result is None: result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
|
126 |
-
if ndim <= 2: result = result[0, 0]
|
127 |
-
|
128 |
-
return result
|
129 |
-
|
130 |
-
@nb.jit(nopython=True)
|
131 |
-
def backtrace(trace):
|
132 |
-
i = trace.shape[0] - 1
|
133 |
-
j = trace.shape[1] - 1
|
134 |
-
|
135 |
-
trace[0, :] = 2
|
136 |
-
trace[:, 0] = 1
|
137 |
-
|
138 |
-
result = []
|
139 |
-
while i > 0 or j > 0:
|
140 |
-
result.append((i - 1, j - 1))
|
141 |
-
|
142 |
-
if trace[i, j] == 0:
|
143 |
-
i -= 1
|
144 |
-
j -= 1
|
145 |
-
elif trace[i, j] == 1: i -= 1
|
146 |
-
elif trace[i, j] == 2: j -= 1
|
147 |
-
else: raise ValueError
|
148 |
-
|
149 |
-
return np.array(result)[::-1, :].T
|
150 |
-
|
151 |
-
|
152 |
-
@nb.jit(nopython=True, parallel=True)
|
153 |
-
def dtw_cpu(x):
|
154 |
-
N, M = x.shape
|
155 |
-
|
156 |
-
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
|
157 |
-
trace = -np.ones((N + 1, M + 1), dtype=np.float32)
|
158 |
-
cost[0, 0] = 0
|
159 |
-
|
160 |
-
for j in range(1, M + 1):
|
161 |
-
for i in range(1, N + 1):
|
162 |
-
c0 = cost[i - 1, j - 1]
|
163 |
-
c1 = cost[i - 1, j]
|
164 |
-
c2 = cost[i, j - 1]
|
165 |
-
|
166 |
-
if c0 < c1 and c0 < c2: c, t = c0, 0
|
167 |
-
elif c1 < c0 and c1 < c2: c, t = c1, 1
|
168 |
-
else: c, t = c2, 2
|
169 |
-
|
170 |
-
cost[i, j] = x[i - 1, j - 1] + c
|
171 |
-
trace[i, j] = t
|
172 |
-
|
173 |
-
return backtrace(trace)
|
174 |
-
|
175 |
-
def dtw(x):
|
176 |
-
return dtw_cpu(x.double().cpu().numpy())
|
177 |
-
|
178 |
-
def find_alignment(model, tokenizer, text_tokens, mel, num_frames, *, medfilt_width = 7, qk_scale = 1.0):
|
179 |
-
if len(text_tokens) == 0: return []
|
180 |
-
|
181 |
-
tokens = torch.tensor([*tokenizer.sot_sequence, tokenizer.no_timestamps, *text_tokens, tokenizer.eot]).to(model.device)
|
182 |
-
|
183 |
-
QKs = [None] * model.dims.n_text_layer
|
184 |
-
hooks = [block.cross_attn.register_forward_hook(lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])) for i, block in enumerate(model.decoder.blocks)]
|
185 |
-
|
186 |
-
with torch.no_grad(), disable_sdpa():
|
187 |
-
token_probs = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0][len(tokenizer.sot_sequence) :, : tokenizer.eot].softmax(dim=-1)
|
188 |
-
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens].tolist()
|
189 |
-
|
190 |
-
for hook in hooks:
|
191 |
-
hook.remove()
|
192 |
-
|
193 |
-
weights = (torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])[:, :, : num_frames // 2] * qk_scale).softmax(dim=-1)
|
194 |
-
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
|
195 |
-
weights = median_filter((weights - mean) / std, medfilt_width)
|
196 |
-
|
197 |
-
text_indices, time_indices = dtw(-weights.mean(axis=0)[len(tokenizer.sot_sequence) : -1])
|
198 |
-
|
199 |
-
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
|
200 |
-
if len(word_tokens) <= 1: return []
|
201 |
-
|
202 |
-
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
|
203 |
-
jump_times = time_indices[np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)] / TOKENS_PER_SECOND
|
204 |
-
|
205 |
-
return [WordTiming(word, tokens, start, end, probability) for word, tokens, start, end, probability in zip(words, word_tokens, jump_times[word_boundaries[:-1]], jump_times[word_boundaries[1:]], [np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:])])]
|
206 |
-
|
207 |
-
def add_word_timestamps(*, segments, model, tokenizer, mel, num_frames, prepend_punctuations = "\"'“¿([{-", append_punctuations = "\"'.。,,!!??::”)]}、", last_speech_timestamp, **kwargs):
|
208 |
-
if len(segments) == 0: return
|
209 |
-
|
210 |
-
text_tokens_per_segment = [[token for token in segment["tokens"] if token < tokenizer.eot] for segment in segments]
|
211 |
-
|
212 |
-
text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
|
213 |
-
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
|
214 |
-
|
215 |
-
word_durations = np.array([t.end - t.start for t in alignment])
|
216 |
-
word_durations = word_durations[word_durations.nonzero()]
|
217 |
-
|
218 |
-
median_duration = min(0.7, float(np.median(word_durations) if len(word_durations) > 0 else 0.0))
|
219 |
-
max_duration = median_duration * 2
|
220 |
-
|
221 |
-
if len(word_durations) > 0:
|
222 |
-
sentence_end_marks = ".。!!??"
|
223 |
-
for i in range(1, len(alignment)):
|
224 |
-
if alignment[i].end - alignment[i].start > max_duration:
|
225 |
-
if alignment[i].word in sentence_end_marks: alignment[i].end = alignment[i].start + max_duration
|
226 |
-
elif alignment[i - 1].word in sentence_end_marks: alignment[i].start = alignment[i].end - max_duration
|
227 |
-
|
228 |
-
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
|
229 |
-
|
230 |
-
time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
|
231 |
-
word_index = 0
|
232 |
-
|
233 |
-
for segment, text_tokens in zip(segments, text_tokens_per_segment):
|
234 |
-
saved_tokens = 0
|
235 |
-
words = []
|
236 |
-
|
237 |
-
while word_index < len(alignment) and saved_tokens < len(text_tokens):
|
238 |
-
timing = alignment[word_index]
|
239 |
-
|
240 |
-
if timing.word: words.append(dict(word=timing.word, start=round(time_offset + timing.start, 2), end=round(time_offset + timing.end, 2), probability=timing.probability))
|
241 |
-
|
242 |
-
saved_tokens += len(timing.tokens)
|
243 |
-
word_index += 1
|
244 |
-
|
245 |
-
if len(words) > 0:
|
246 |
-
if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (words[0]["end"] - words[0]["start"] > max_duration or (len(words) > 1 and words[1]["end"] - words[0]["start"] > max_duration * 2)):
|
247 |
-
if (len(words) > 1 and words[1]["end"] - words[1]["start"] > max_duration): words[0]["end"] = words[1]["start"] = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
|
248 |
-
words[0]["start"] = max(0, words[0]["end"] - max_duration)
|
249 |
-
|
250 |
-
if (segment["start"] < words[0]["end"] and segment["start"] - 0.5 > words[0]["start"]): words[0]["start"] = max(0, min(words[0]["end"] - median_duration, segment["start"]))
|
251 |
-
else: segment["start"] = words[0]["start"]
|
252 |
-
|
253 |
-
if (segment["end"] > words[-1]["start"] and segment["end"] + 0.5 < words[-1]["end"]): words[-1]["end"] = max(words[-1]["start"] + median_duration, segment["end"])
|
254 |
-
else: segment["end"] = words[-1]["end"]
|
255 |
-
|
256 |
-
last_speech_timestamp = segment["end"]
|
257 |
-
|
258 |
-
segment["words"] = words
|
259 |
-
|
260 |
-
@lru_cache(maxsize=None)
|
261 |
-
def mel_filters(device, n_mels):
|
262 |
-
assert n_mels in {80, 128}
|
263 |
-
|
264 |
-
with np.load(os.path.join("assets", "models", "speaker_diarization", "assets", "mel_filters.npz"), allow_pickle=False) as f:
|
265 |
-
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
|
266 |
-
|
267 |
-
def log_mel_spectrogram(audio, n_mels = 80, padding = 0, device = None):
|
268 |
-
if not torch.is_tensor(audio):
|
269 |
-
if isinstance(audio, str): audio = load_audio(logging.getLogger(__name__), audio, sample_rate=SAMPLE_RATE).astype(np.float32)
|
270 |
-
audio = torch.from_numpy(audio)
|
271 |
-
|
272 |
-
if device is not None: audio = audio.to(device)
|
273 |
-
if padding > 0: audio = F.pad(audio, (0, padding))
|
274 |
-
|
275 |
-
log_spec = torch.clamp(mel_filters(audio.device, n_mels) @ torch.stft(audio, N_FFT, HOP_LENGTH, window=torch.hann_window(N_FFT).to(audio.device), return_complex=True)[..., :-1].abs() ** 2, min=1e-10).log10()
|
276 |
-
return (torch.maximum(log_spec, log_spec.max() - 8.0) + 4.0) / 4.0
|
277 |
-
|
278 |
-
def pad_or_trim(array, length = N_SAMPLES, *, axis = -1):
|
279 |
-
if torch.is_tensor(array):
|
280 |
-
if array.shape[axis] > length: array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
|
281 |
-
|
282 |
-
if array.shape[axis] < length:
|
283 |
-
pad_widths = [(0, 0)] * array.ndim
|
284 |
-
pad_widths[axis] = (0, length - array.shape[axis])
|
285 |
-
array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
|
286 |
-
else:
|
287 |
-
if array.shape[axis] > length: array = array.take(indices=range(length), axis=axis)
|
288 |
-
|
289 |
-
if array.shape[axis] < length:
|
290 |
-
pad_widths = [(0, 0)] * array.ndim
|
291 |
-
pad_widths[axis] = (0, length - array.shape[axis])
|
292 |
-
array = np.pad(array, pad_widths)
|
293 |
-
|
294 |
-
return array
|
295 |
-
|
296 |
-
def get_end(segments):
|
297 |
-
return next((w["end"] for s in reversed(segments) for w in reversed(s["words"])), segments[-1]["end"] if segments else None)
|
298 |
-
|
299 |
-
def transcribe_function(model, audio, *, verbose = None, temperature = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), compression_ratio_threshold = 2.4, logprob_threshold = -1.0, no_speech_threshold = 0.6, condition_on_previous_text = True, initial_prompt = None, carry_initial_prompt = False, word_timestamps = False, prepend_punctuations = "\"'“¿([{-", append_punctuations = "\"'.。,,!!??::”)]}、", clip_timestamps = "0", hallucination_silence_threshold = None, fp16 = False, **decode_options):
|
300 |
-
dtype = torch.float32
|
301 |
-
decode_options["fp16"] = fp16
|
302 |
-
|
303 |
-
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
|
304 |
-
content_frames = mel.shape[-1] - N_FRAMES
|
305 |
-
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
|
306 |
-
|
307 |
-
if decode_options.get("language", None) is None:
|
308 |
-
if not model.is_multilingual: decode_options["language"] = "vi"
|
309 |
-
else:
|
310 |
-
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
311 |
-
_, probs = model.detect_language(mel_segment)
|
312 |
-
decode_options["language"] = max(probs, key=probs.get)
|
313 |
-
|
314 |
-
if verbose is not None: print(f"{LANGUAGES[decode_options['language']].title()}")
|
315 |
-
|
316 |
-
language = decode_options["language"]
|
317 |
-
task = decode_options.get("task", "transcribe")
|
318 |
-
tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages, language=language, task=task)
|
319 |
-
|
320 |
-
if isinstance(clip_timestamps, str): clip_timestamps = [float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])]
|
321 |
-
seek_points = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
|
322 |
-
|
323 |
-
if len(seek_points) == 0: seek_points.append(0)
|
324 |
-
if len(seek_points) % 2 == 1: seek_points.append(content_frames)
|
325 |
-
|
326 |
-
seek_clips = list(zip(seek_points[::2], seek_points[1::2]))
|
327 |
-
punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
|
328 |
-
|
329 |
-
def decode_with_fallback(segment):
|
330 |
-
temperatures = ([temperature] if isinstance(temperature, (int, float)) else temperature)
|
331 |
-
decode_result = None
|
332 |
-
|
333 |
-
for t in temperatures:
|
334 |
-
kwargs = {**decode_options}
|
335 |
-
|
336 |
-
if t > 0:
|
337 |
-
kwargs.pop("beam_size", None)
|
338 |
-
kwargs.pop("patience", None)
|
339 |
-
else: kwargs.pop("best_of", None)
|
340 |
-
|
341 |
-
decode_result = model.decode(segment, DecodingOptions(**kwargs, temperature=t))
|
342 |
-
needs_fallback = False
|
343 |
-
|
344 |
-
if (compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold): needs_fallback = True
|
345 |
-
if (logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold): needs_fallback = True
|
346 |
-
if (no_speech_threshold is not None and decode_result.no_speech_prob > no_speech_threshold and logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold): needs_fallback = False
|
347 |
-
if not needs_fallback: break
|
348 |
-
|
349 |
-
return decode_result
|
350 |
-
|
351 |
-
clip_idx = 0
|
352 |
-
seek = seek_clips[clip_idx][0]
|
353 |
-
|
354 |
-
input_stride = exact_div(N_FRAMES, model.dims.n_audio_ctx)
|
355 |
-
time_precision = (input_stride * HOP_LENGTH / SAMPLE_RATE)
|
356 |
-
|
357 |
-
all_tokens, all_segments = [], []
|
358 |
-
prompt_reset_since = 0
|
359 |
-
|
360 |
-
remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
|
361 |
-
|
362 |
-
if initial_prompt is not None:
|
363 |
-
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
364 |
-
all_tokens.extend(initial_prompt_tokens)
|
365 |
-
remaining_prompt_length -= len(initial_prompt_tokens)
|
366 |
-
else: initial_prompt_tokens = []
|
367 |
-
|
368 |
-
def new_segment(*, start, end, tokens, result):
|
369 |
-
tokens = tokens.tolist()
|
370 |
-
return {"seek": seek, "start": start, "end": end, "text": tokenizer.decode([token for token in tokens if token < tokenizer.eot]), "tokens": tokens, "temperature": result.temperature, "avg_logprob": result.avg_logprob, "compression_ratio": result.compression_ratio, "no_speech_prob": result.no_speech_prob}
|
371 |
-
|
372 |
-
with tqdm.tqdm(total=content_frames, unit="frames", disable=verbose is not False) as pbar:
|
373 |
-
last_speech_timestamp = 0.0
|
374 |
-
while clip_idx < len(seek_clips):
|
375 |
-
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
|
376 |
-
if seek < seek_clip_start: seek = seek_clip_start
|
377 |
-
|
378 |
-
if seek >= seek_clip_end:
|
379 |
-
clip_idx += 1
|
380 |
-
if clip_idx < len(seek_clips): seek = seek_clips[clip_idx][0]
|
381 |
-
continue
|
382 |
-
|
383 |
-
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
384 |
-
window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
|
385 |
-
|
386 |
-
segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
|
387 |
-
mel_segment = mel[:, seek : seek + segment_size]
|
388 |
-
|
389 |
-
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
390 |
-
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
391 |
-
|
392 |
-
if carry_initial_prompt: decode_options["prompt"] = initial_prompt_tokens + all_tokens[max(len(initial_prompt_tokens), prompt_reset_since):][-remaining_prompt_length:]
|
393 |
-
else: decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
394 |
-
|
395 |
-
result = decode_with_fallback(mel_segment)
|
396 |
-
tokens = torch.tensor(result.tokens)
|
397 |
-
|
398 |
-
if no_speech_threshold is not None:
|
399 |
-
should_skip = result.no_speech_prob > no_speech_threshold
|
400 |
-
if (logprob_threshold is not None and result.avg_logprob > logprob_threshold):
|
401 |
-
should_skip = False
|
402 |
-
|
403 |
-
if should_skip:
|
404 |
-
seek += segment_size
|
405 |
-
continue
|
406 |
-
|
407 |
-
previous_seek = seek
|
408 |
-
current_segments = []
|
409 |
-
|
410 |
-
def word_anomaly_score(word):
|
411 |
-
probability = word.get("probability", 0.0)
|
412 |
-
duration = word["end"] - word["start"]
|
413 |
-
score = 0.0
|
414 |
-
|
415 |
-
if probability < 0.15: score += 1.0
|
416 |
-
if duration < 0.133: score += (0.133 - duration) * 15
|
417 |
-
if duration > 2.0: score += duration - 2.0
|
418 |
-
|
419 |
-
return score
|
420 |
-
|
421 |
-
def is_segment_anomaly(segment):
|
422 |
-
if segment is None or not segment["words"]: return False
|
423 |
-
|
424 |
-
words = [w for w in segment["words"] if w["word"] not in punctuation]
|
425 |
-
words = words[:8]
|
426 |
-
|
427 |
-
score = sum(word_anomaly_score(w) for w in words)
|
428 |
-
|
429 |
-
return score >= 3 or score + 0.01 >= len(words)
|
430 |
-
|
431 |
-
def next_words_segment(segments):
|
432 |
-
return next((s for s in segments if s["words"]), None)
|
433 |
-
|
434 |
-
timestamp_tokens = tokens.ge(tokenizer.timestamp_begin)
|
435 |
-
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
436 |
-
|
437 |
-
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
|
438 |
-
consecutive.add_(1)
|
439 |
-
|
440 |
-
if len(consecutive) > 0:
|
441 |
-
slices = consecutive.tolist()
|
442 |
-
if single_timestamp_ending:
|
443 |
-
slices.append(len(tokens))
|
444 |
-
|
445 |
-
last_slice = 0
|
446 |
-
for current_slice in slices:
|
447 |
-
sliced_tokens = tokens[last_slice:current_slice]
|
448 |
-
current_segments.append(new_segment(start=time_offset + (sliced_tokens[0].item() - tokenizer.timestamp_begin) * time_precision, end=time_offset + (sliced_tokens[-1].item() - tokenizer.timestamp_begin) * time_precision, tokens=sliced_tokens, result=result))
|
449 |
-
last_slice = current_slice
|
450 |
-
|
451 |
-
if single_timestamp_ending: seek += segment_size
|
452 |
-
else: seek += (tokens[last_slice - 1].item() - tokenizer.timestamp_begin) * input_stride
|
453 |
-
else:
|
454 |
-
duration = segment_duration
|
455 |
-
|
456 |
-
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
457 |
-
if (len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin): duration = (timestamps[-1].item() - tokenizer.timestamp_begin) * time_precision
|
458 |
-
|
459 |
-
current_segments.append(new_segment(start=time_offset, end=time_offset + duration, tokens=tokens, result=result))
|
460 |
-
seek += segment_size
|
461 |
-
|
462 |
-
if word_timestamps:
|
463 |
-
add_word_timestamps(segments=current_segments, model=model, tokenizer=tokenizer, mel=mel_segment, num_frames=segment_size, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, last_speech_timestamp=last_speech_timestamp)
|
464 |
-
|
465 |
-
if not single_timestamp_ending:
|
466 |
-
last_word_end = get_end(current_segments)
|
467 |
-
if last_word_end is not None and last_word_end > time_offset: seek = round(last_word_end * FRAMES_PER_SECOND)
|
468 |
-
|
469 |
-
if hallucination_silence_threshold is not None:
|
470 |
-
threshold = hallucination_silence_threshold
|
471 |
-
|
472 |
-
if not single_timestamp_ending:
|
473 |
-
last_word_end = get_end(current_segments)
|
474 |
-
if last_word_end is not None and last_word_end > time_offset: seek = round(last_word_end * FRAMES_PER_SECOND) if (window_end_time - last_word_end) > threshold else (previous_seek + segment_size)
|
475 |
-
|
476 |
-
first_segment = next_words_segment(current_segments)
|
477 |
-
|
478 |
-
if first_segment is not None and is_segment_anomaly(first_segment):
|
479 |
-
gap = first_segment["start"] - time_offset
|
480 |
-
|
481 |
-
if gap > threshold:
|
482 |
-
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
|
483 |
-
continue
|
484 |
-
|
485 |
-
hal_last_end = last_speech_timestamp
|
486 |
-
|
487 |
-
for si in range(len(current_segments)):
|
488 |
-
segment = current_segments[si]
|
489 |
-
if not segment["words"]: continue
|
490 |
-
|
491 |
-
if is_segment_anomaly(segment):
|
492 |
-
next_segment = next_words_segment(current_segments[si + 1 :])
|
493 |
-
hal_next_start = next_segment["words"][0]["start"] if next_segment is not None else (time_offset + segment_duration)
|
494 |
-
|
495 |
-
if (segment["start"] - hal_last_end > threshold or segment["start"] < threshold or segment["start"] - time_offset < 2.0) and (hal_next_start - segment["end"] > threshold or is_segment_anomaly(next_segment) or window_end_time - segment["end"] < 2.0):
|
496 |
-
seek = round(max(time_offset + 1, segment["start"]) * FRAMES_PER_SECOND)
|
497 |
-
if content_duration - segment["end"] < threshold: seek = content_frames
|
498 |
-
|
499 |
-
current_segments[si:] = []
|
500 |
-
break
|
501 |
-
|
502 |
-
hal_last_end = segment["end"]
|
503 |
-
|
504 |
-
last_word_end = get_end(current_segments)
|
505 |
-
if last_word_end is not None: last_speech_timestamp = last_word_end
|
506 |
-
|
507 |
-
for _, segment in enumerate(current_segments):
|
508 |
-
if segment["start"] == segment["end"] or segment["text"].strip() == "":
|
509 |
-
segment["text"] = ""
|
510 |
-
segment["tokens"] = []
|
511 |
-
segment["words"] = []
|
512 |
-
|
513 |
-
all_segments.extend([{"id": i, **segment} for i, segment in enumerate(current_segments, start=len(all_segments))])
|
514 |
-
all_tokens.extend([token for segment in current_segments for token in segment["tokens"]])
|
515 |
-
|
516 |
-
if not condition_on_previous_text or result.temperature > 0.5: prompt_reset_since = len(all_tokens)
|
517 |
-
pbar.update(min(content_frames, seek) - previous_seek)
|
518 |
-
|
519 |
-
return dict(text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), segments=all_segments, language=language)
|
520 |
-
|
521 |
-
def compression_ratio(text):
|
522 |
-
text_bytes = text.encode("utf-8")
|
523 |
-
return len(text_bytes) / len(zlib.compress(text_bytes))
|
524 |
-
|
525 |
-
def sinusoids(length, channels, max_timescale=10000):
|
526 |
-
assert channels % 2 == 0
|
527 |
-
|
528 |
-
scaled_time = torch.arange(length)[:, np.newaxis] * torch.exp(-(np.log(max_timescale) / (channels // 2 - 1)) * torch.arange(channels // 2))[np.newaxis, :]
|
529 |
-
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
530 |
-
|
531 |
-
@torch.no_grad()
|
532 |
-
def detect_language_function(model, mel, tokenizer = None):
|
533 |
-
if tokenizer is None: tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
|
534 |
-
if (tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence): raise ValueError
|
535 |
-
|
536 |
-
single = mel.ndim == 2
|
537 |
-
|
538 |
-
if single: mel = mel.unsqueeze(0)
|
539 |
-
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state): mel = model.encoder(mel)
|
540 |
-
|
541 |
-
n_audio = mel.shape[0]
|
542 |
-
logits = model.logits(torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device), mel)[:, 0]
|
543 |
-
|
544 |
-
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
545 |
-
mask[list(tokenizer.all_language_tokens)] = False
|
546 |
-
|
547 |
-
logits[:, mask] = -np.inf
|
548 |
-
|
549 |
-
language_tokens = logits.argmax(dim=-1)
|
550 |
-
language_probs = [{c: logits.softmax(dim=-1).cpu()[i, j].item() for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)} for i in range(n_audio)]
|
551 |
-
|
552 |
-
if single:
|
553 |
-
language_tokens = language_tokens[0]
|
554 |
-
language_probs = language_probs[0]
|
555 |
-
|
556 |
-
return language_tokens, language_probs
|
557 |
-
|
558 |
-
@lru_cache(maxsize=None)
|
559 |
-
def get_tokenizer(multilingual, *, num_languages = 99, language = None, task = None):
|
560 |
-
if language is not None:
|
561 |
-
language = language.lower()
|
562 |
-
if language not in LANGUAGES:
|
563 |
-
if language in TO_LANGUAGE_CODE: language = TO_LANGUAGE_CODE[language]
|
564 |
-
else: raise ValueError
|
565 |
-
|
566 |
-
if multilingual:
|
567 |
-
encoding_name = "multilingual"
|
568 |
-
language = language or "en"
|
569 |
-
task = task or "transcribe"
|
570 |
-
else:
|
571 |
-
encoding_name = "gpt2"
|
572 |
-
language = None
|
573 |
-
task = None
|
574 |
-
|
575 |
-
return Tokenizer(encoding_name=encoding_name, num_languages=num_languages, language=language, task=task)
|
576 |
-
|
577 |
-
@lru_cache(maxsize=None)
|
578 |
-
def get_encoding(name = "gpt2", num_languages = 99):
|
579 |
-
vocab_path = os.path.join("assets", "models", "speaker_diarization", "assets", f"{name}.tiktoken")
|
580 |
-
ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in open(vocab_path) if line)}
|
581 |
-
|
582 |
-
n_vocab = len(ranks)
|
583 |
-
special_tokens = {}
|
584 |
-
|
585 |
-
specials = ["<|endoftext|>", "<|startoftranscript|>", *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], "<|translate|>", "<|transcribe|>", "<|startoflm|>", "<|startofprev|>", "<|nospeech|>", "<|notimestamps|>", *[f"<|{i * 0.02:.2f}|>" for i in range(1501)]]
|
586 |
-
|
587 |
-
for token in specials:
|
588 |
-
special_tokens[token] = n_vocab
|
589 |
-
n_vocab += 1
|
590 |
-
|
591 |
-
return tiktoken.Encoding(name=os.path.basename(vocab_path), explicit_n_vocab=n_vocab, pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", mergeable_ranks=ranks, special_tokens=special_tokens)
|
592 |
-
|
593 |
-
class DecodingOptions:
|
594 |
-
def __init__(self, task = "transcribe", language = None, temperature = 0.0, sample_len = None, best_of = None, beam_size = None, patience = None, length_penalty = None, prompt = None, prefix = None, suppress_tokens = "-1", suppress_blank = True, without_timestamps = False, max_initial_timestamp = 1.0, fp16 = False):
|
595 |
-
self.task = task
|
596 |
-
self.language = language
|
597 |
-
self.temperature = temperature
|
598 |
-
self.sample_len = sample_len
|
599 |
-
self.best_of = best_of
|
600 |
-
self.beam_size = beam_size
|
601 |
-
self.patience = patience
|
602 |
-
self.length_penalty = length_penalty
|
603 |
-
self.prompt = prompt
|
604 |
-
self.prefix = prefix
|
605 |
-
self.suppress_tokens = suppress_tokens
|
606 |
-
self.suppress_blank = suppress_blank
|
607 |
-
self.without_timestamps = without_timestamps
|
608 |
-
self.max_initial_timestamp = max_initial_timestamp
|
609 |
-
self.fp16 = fp16
|
610 |
-
|
611 |
-
@torch.no_grad()
|
612 |
-
def decode_function(model, mel, options = DecodingOptions(), **kwargs):
|
613 |
-
if single := mel.ndim == 2: mel = mel.unsqueeze(0)
|
614 |
-
if kwargs: options = replace(options, **kwargs)
|
615 |
-
|
616 |
-
result = DecodingTask(model, options).run(mel)
|
617 |
-
return result[0] if single else result
|
618 |
-
|
619 |
-
@dataclass
|
620 |
-
class ModelDimensions:
|
621 |
-
n_mels: int
|
622 |
-
n_audio_ctx: int
|
623 |
-
n_audio_state: int
|
624 |
-
n_audio_head: int
|
625 |
-
n_audio_layer: int
|
626 |
-
n_vocab: int
|
627 |
-
n_text_ctx: int
|
628 |
-
n_text_state: int
|
629 |
-
n_text_head: int
|
630 |
-
n_text_layer: int
|
631 |
-
|
632 |
-
class LayerNorm(nn.LayerNorm):
|
633 |
-
def forward(self, x):
|
634 |
-
return super().forward(x.float()).type(x.dtype)
|
635 |
-
|
636 |
-
class Linear(nn.Linear):
|
637 |
-
def forward(self, x):
|
638 |
-
return F.linear(x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype))
|
639 |
-
|
640 |
-
class Conv1d(nn.Conv1d):
|
641 |
-
def _conv_forward(self, x, weight, bias):
|
642 |
-
return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))
|
643 |
-
|
644 |
-
class TextDecoder(nn.Module):
|
645 |
-
def __init__(self, n_vocab, n_ctx, n_state, n_head, n_layer):
|
646 |
-
super().__init__()
|
647 |
-
|
648 |
-
self.token_embedding = nn.Embedding(n_vocab, n_state)
|
649 |
-
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
|
650 |
-
|
651 |
-
self.blocks = nn.ModuleList([ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)])
|
652 |
-
self.ln = LayerNorm(n_state)
|
653 |
-
self.register_buffer("mask", torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1), persistent=False)
|
654 |
-
|
655 |
-
def forward(self, x, xa, kv_cache = None):
|
656 |
-
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
657 |
-
x = (self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]).to(xa.dtype)
|
658 |
-
|
659 |
-
for block in self.blocks:
|
660 |
-
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
661 |
-
|
662 |
-
x = self.ln(x)
|
663 |
-
return (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
|
664 |
-
|
665 |
-
class AudioEncoder(nn.Module):
|
666 |
-
def __init__(self, n_mels, n_ctx, n_state, n_head, n_layer):
|
667 |
-
super().__init__()
|
668 |
-
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
669 |
-
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
670 |
-
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
|
671 |
-
|
672 |
-
self.blocks = nn.ModuleList([ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)])
|
673 |
-
self.ln_post = LayerNorm(n_state)
|
674 |
-
|
675 |
-
def forward(self, x):
|
676 |
-
x = F.gelu(self.conv2(F.gelu(self.conv1(x)))).permute(0, 2, 1)
|
677 |
-
|
678 |
-
assert x.shape[1:] == self.positional_embedding.shape
|
679 |
-
x = (x + self.positional_embedding).to(x.dtype)
|
680 |
-
|
681 |
-
for block in self.blocks:
|
682 |
-
x = block(x)
|
683 |
-
|
684 |
-
return self.ln_post(x)
|
685 |
-
|
686 |
-
class Whisper(nn.Module):
|
687 |
-
def __init__(self, dims):
|
688 |
-
super().__init__()
|
689 |
-
self.dims = dims
|
690 |
-
self.encoder = AudioEncoder(self.dims.n_mels, self.dims.n_audio_ctx, self.dims.n_audio_state, self.dims.n_audio_head, self.dims.n_audio_layer)
|
691 |
-
self.decoder = TextDecoder(self.dims.n_vocab, self.dims.n_text_ctx, self.dims.n_text_state, self.dims.n_text_head, self.dims.n_text_layer)
|
692 |
-
|
693 |
-
all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
|
694 |
-
all_heads[self.dims.n_text_layer // 2 :] = True
|
695 |
-
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
|
696 |
-
|
697 |
-
def set_alignment_heads(self, dump):
|
698 |
-
self.register_buffer("alignment_heads", torch.from_numpy(np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()).reshape(self.dims.n_text_layer, self.dims.n_text_head).to_sparse(), persistent=False)
|
699 |
-
|
700 |
-
def embed_audio(self, mel):
|
701 |
-
return self.encoder(mel)
|
702 |
-
|
703 |
-
def logits(self, tokens, audio_features):
|
704 |
-
return self.decoder(tokens, audio_features)
|
705 |
-
|
706 |
-
def forward(self, mel, tokens):
|
707 |
-
return self.decoder(tokens, self.encoder(mel))
|
708 |
-
|
709 |
-
@property
|
710 |
-
def device(self):
|
711 |
-
return next(self.parameters()).device
|
712 |
-
|
713 |
-
@property
|
714 |
-
def is_multilingual(self):
|
715 |
-
return self.dims.n_vocab >= 51865
|
716 |
-
|
717 |
-
@property
|
718 |
-
def num_languages(self):
|
719 |
-
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
|
720 |
-
|
721 |
-
def install_kv_cache_hooks(self, cache = None):
|
722 |
-
cache = {**cache} if cache is not None else {}
|
723 |
-
hooks = []
|
724 |
-
|
725 |
-
def save_to_cache(module, _, output):
|
726 |
-
cache[module] = output if module not in cache or output.shape[1] > self.dims.n_text_ctx else torch.cat([cache[module], output], dim=1).detach()
|
727 |
-
return cache[module]
|
728 |
-
|
729 |
-
def install_hooks(layer: nn.Module):
|
730 |
-
if isinstance(layer, MultiHeadAttention):
|
731 |
-
hooks.append(layer.key.register_forward_hook(save_to_cache))
|
732 |
-
hooks.append(layer.value.register_forward_hook(save_to_cache))
|
733 |
-
|
734 |
-
self.decoder.apply(install_hooks)
|
735 |
-
return cache, hooks
|
736 |
-
|
737 |
-
detect_language = detect_language_function
|
738 |
-
transcribe = transcribe_function
|
739 |
-
decode = decode_function
|
740 |
-
|
741 |
-
class ResidualAttentionBlock(nn.Module):
|
742 |
-
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
|
743 |
-
super().__init__()
|
744 |
-
|
745 |
-
self.attn = MultiHeadAttention(n_state, n_head)
|
746 |
-
self.attn_ln = LayerNorm(n_state)
|
747 |
-
|
748 |
-
self.cross_attn = (MultiHeadAttention(n_state, n_head) if cross_attention else None)
|
749 |
-
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
750 |
-
|
751 |
-
n_mlp = n_state * 4
|
752 |
-
self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
|
753 |
-
self.mlp_ln = LayerNorm(n_state)
|
754 |
-
|
755 |
-
def forward(self, x, xa = None, mask = None, kv_cache = None):
|
756 |
-
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
|
757 |
-
if self.cross_attn: x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
|
758 |
-
|
759 |
-
return x + self.mlp(self.mlp_ln(x))
|
760 |
-
|
761 |
-
class MultiHeadAttention(nn.Module):
|
762 |
-
def __init__(self, n_state, n_head):
|
763 |
-
super().__init__()
|
764 |
-
self.n_head = n_head
|
765 |
-
self.query = Linear(n_state, n_state)
|
766 |
-
self.key = Linear(n_state, n_state, bias=False)
|
767 |
-
self.value = Linear(n_state, n_state)
|
768 |
-
self.out = Linear(n_state, n_state)
|
769 |
-
|
770 |
-
def forward(self, x, xa = None, mask = None, kv_cache = None):
|
771 |
-
k, v = (self.key(x if xa is None else xa), self.value(x if xa is None else xa)) if kv_cache is None or xa is None or self.key not in kv_cache else (kv_cache[self.key], kv_cache[self.value])
|
772 |
-
wv, qk = self.qkv_attention(self.query(x), k, v, mask)
|
773 |
-
|
774 |
-
return self.out(wv), qk
|
775 |
-
|
776 |
-
def qkv_attention(self, q, k, v, mask = None):
|
777 |
-
_, n_ctx, _ = q.shape
|
778 |
-
|
779 |
-
q, k, v = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3), k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3), v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
780 |
-
return scaled_dot_product_attention(q, k, v, is_causal=mask is not None and n_ctx > 1).permute(0, 2, 1, 3).flatten(start_dim=2), None
|
781 |
-
|
782 |
-
class LogitFilter:
|
783 |
-
def apply(self, logits, tokens):
|
784 |
-
pass
|
785 |
-
|
786 |
-
class SuppressBlank(LogitFilter):
|
787 |
-
def __init__(self, tokenizer, sample_begin):
|
788 |
-
self.tokenizer = tokenizer
|
789 |
-
self.sample_begin = sample_begin
|
790 |
-
|
791 |
-
def apply(self, logits, tokens):
|
792 |
-
if tokens.shape[1] == self.sample_begin: logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
793 |
-
|
794 |
-
class SuppressTokens(LogitFilter):
|
795 |
-
def __init__(self, suppress_tokens):
|
796 |
-
self.suppress_tokens = list(suppress_tokens)
|
797 |
-
|
798 |
-
def apply(self, logits, tokens):
|
799 |
-
logits[:, self.suppress_tokens] = -np.inf
|
800 |
-
|
801 |
-
class Inference:
|
802 |
-
def logits(self, tokens, audio_features):
|
803 |
-
pass
|
804 |
-
|
805 |
-
def rearrange_kv_cache(self, source_indices):
|
806 |
-
pass
|
807 |
-
|
808 |
-
def cleanup_caching(self):
|
809 |
-
pass
|
810 |
-
|
811 |
-
class PyTorchInference(Inference):
|
812 |
-
def __init__(self, model, initial_token_length):
|
813 |
-
self.model = model
|
814 |
-
self.initial_token_length = initial_token_length
|
815 |
-
self.kv_cache = {}
|
816 |
-
self.hooks = []
|
817 |
-
|
818 |
-
self.kv_modules = [block.attn.key for block in self.model.decoder.blocks] + [block.attn.value for block in self.model.decoder.blocks]
|
819 |
-
|
820 |
-
def logits(self, tokens, audio_features):
|
821 |
-
if not self.kv_cache: self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
|
822 |
-
if tokens.shape[-1] > self.initial_token_length: tokens = tokens[:, -1:]
|
823 |
-
|
824 |
-
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
|
825 |
-
|
826 |
-
def cleanup_caching(self):
|
827 |
-
for hook in self.hooks:
|
828 |
-
hook.remove()
|
829 |
-
|
830 |
-
self.kv_cache = {}
|
831 |
-
self.hooks = []
|
832 |
-
|
833 |
-
def rearrange_kv_cache(self, source_indices):
|
834 |
-
if source_indices != list(range(len(source_indices))):
|
835 |
-
for module in self.kv_modules:
|
836 |
-
self.kv_cache[module] = self.kv_cache[module][source_indices].detach()
|
837 |
-
|
838 |
-
class SequenceRanker:
|
839 |
-
def rank(self, tokens, sum_logprobs):
|
840 |
-
pass
|
841 |
-
|
842 |
-
class MaximumLikelihoodRanker(SequenceRanker):
|
843 |
-
def __init__(self, length_penalty):
|
844 |
-
self.length_penalty = length_penalty
|
845 |
-
|
846 |
-
def rank(self, tokens, sum_logprobs):
|
847 |
-
def scores(logprobs, lengths):
|
848 |
-
result = []
|
849 |
-
for logprob, length in zip(logprobs, lengths):
|
850 |
-
result.append(logprob / (length if self.length_penalty is None else ((5 + length) / 6) ** self.length_penalty))
|
851 |
-
return result
|
852 |
-
|
853 |
-
return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, [[len(t) for t in s] for s in tokens])]
|
854 |
-
|
855 |
-
class TokenDecoder:
|
856 |
-
def reset(self):
|
857 |
-
pass
|
858 |
-
|
859 |
-
def update(self, tokens, logits, sum_logprobs):
|
860 |
-
pass
|
861 |
-
|
862 |
-
def finalize(self, tokens, sum_logprobs):
|
863 |
-
pass
|
864 |
-
|
865 |
-
|
866 |
-
class GreedyDecoder(TokenDecoder):
|
867 |
-
def __init__(self, temperature, eot):
|
868 |
-
self.temperature = temperature
|
869 |
-
self.eot = eot
|
870 |
-
|
871 |
-
def update(self, tokens, logits, sum_logprobs):
|
872 |
-
next_tokens = logits.argmax(dim=-1) if self.temperature == 0 else Categorical(logits=logits / self.temperature).sample()
|
873 |
-
|
874 |
-
logprobs = F.log_softmax(logits.float(), dim=-1)
|
875 |
-
sum_logprobs += logprobs[torch.arange(logprobs.shape[0]), next_tokens] * (tokens[:, -1] != self.eot)
|
876 |
-
|
877 |
-
next_tokens[tokens[:, -1] == self.eot] = self.eot
|
878 |
-
tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
|
879 |
-
|
880 |
-
return tokens, (tokens[:, -1] == self.eot).all()
|
881 |
-
|
882 |
-
def finalize(self, tokens, sum_logprobs):
|
883 |
-
return F.pad(tokens, (0, 1), value=self.eot), sum_logprobs.tolist()
|
884 |
-
|
885 |
-
class BeamSearchDecoder(TokenDecoder):
|
886 |
-
def __init__(self, beam_size, eot, inference, patience = None):
|
887 |
-
self.beam_size = beam_size
|
888 |
-
self.eot = eot
|
889 |
-
self.inference = inference
|
890 |
-
self.patience = patience or 1.0
|
891 |
-
self.max_candidates = round(beam_size * self.patience)
|
892 |
-
self.finished_sequences = None
|
893 |
-
|
894 |
-
assert (self.max_candidates > 0)
|
895 |
-
|
896 |
-
def reset(self):
|
897 |
-
self.finished_sequences = None
|
898 |
-
|
899 |
-
def update(self, tokens, logits, sum_logprobs):
|
900 |
-
if tokens.shape[0] % self.beam_size != 0: raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
|
901 |
-
|
902 |
-
n_audio = tokens.shape[0] // self.beam_size
|
903 |
-
if self.finished_sequences is None: self.finished_sequences = [{} for _ in range(n_audio)]
|
904 |
-
|
905 |
-
logprobs = F.log_softmax(logits.float(), dim=-1)
|
906 |
-
next_tokens, source_indices, finished_sequences = [], [], []
|
907 |
-
|
908 |
-
for i in range(n_audio):
|
909 |
-
scores, sources, finished = {}, {}, {}
|
910 |
-
|
911 |
-
for j in range(self.beam_size):
|
912 |
-
idx = i * self.beam_size + j
|
913 |
-
prefix = tokens[idx].tolist()
|
914 |
-
for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
|
915 |
-
sequence = tuple(prefix + [token.item()])
|
916 |
-
scores[sequence] = (sum_logprobs[idx] + logprob).item()
|
917 |
-
sources[sequence] = idx
|
918 |
-
|
919 |
-
saved = 0
|
920 |
-
|
921 |
-
for sequence in sorted(scores, key=scores.get, reverse=True):
|
922 |
-
if sequence[-1] == self.eot: finished[sequence] = scores[sequence]
|
923 |
-
else:
|
924 |
-
sum_logprobs[len(next_tokens)] = scores[sequence]
|
925 |
-
next_tokens.append(sequence)
|
926 |
-
source_indices.append(sources[sequence])
|
927 |
-
|
928 |
-
saved += 1
|
929 |
-
if saved == self.beam_size: break
|
930 |
-
|
931 |
-
finished_sequences.append(finished)
|
932 |
-
|
933 |
-
self.inference.rearrange_kv_cache(source_indices)
|
934 |
-
assert len(self.finished_sequences) == len(finished_sequences)
|
935 |
-
|
936 |
-
for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
|
937 |
-
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
|
938 |
-
if len(previously_finished) >= self.max_candidates: break
|
939 |
-
previously_finished[seq] = newly_finished[seq]
|
940 |
-
|
941 |
-
return torch.tensor(next_tokens, device=tokens.device), all(len(sequences) >= self.max_candidates for sequences in self.finished_sequences)
|
942 |
-
|
943 |
-
def finalize(self, preceding_tokens, sum_logprobs):
|
944 |
-
sum_logprobs = sum_logprobs.cpu()
|
945 |
-
|
946 |
-
for i, sequences in enumerate(self.finished_sequences):
|
947 |
-
if (len(sequences) < self.beam_size):
|
948 |
-
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
|
949 |
-
sequence = preceding_tokens[i, j].tolist() + [self.eot]
|
950 |
-
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
|
951 |
-
if len(sequences) >= self.beam_size: break
|
952 |
-
|
953 |
-
return [[torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences], [list(sequences.values()) for sequences in self.finished_sequences]
|
954 |
-
|
955 |
-
class ApplyTimestampRules(LogitFilter):
|
956 |
-
def __init__(self, tokenizer, sample_begin, max_initial_timestamp_index):
|
957 |
-
self.tokenizer = tokenizer
|
958 |
-
self.sample_begin = sample_begin
|
959 |
-
self.max_initial_timestamp_index = max_initial_timestamp_index
|
960 |
-
|
961 |
-
def apply(self, logits, tokens):
|
962 |
-
if self.tokenizer.no_timestamps is not None: logits[:, self.tokenizer.no_timestamps] = -np.inf
|
963 |
-
|
964 |
-
for k in range(tokens.shape[0]):
|
965 |
-
sampled_tokens = tokens[k, self.sample_begin :]
|
966 |
-
seq = [t for t in sampled_tokens.tolist()]
|
967 |
-
|
968 |
-
last_was_timestamp = (len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin)
|
969 |
-
penultimate_was_timestamp = (len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin)
|
970 |
-
|
971 |
-
if last_was_timestamp:
|
972 |
-
if penultimate_was_timestamp: logits[k, self.tokenizer.timestamp_begin :] = -np.inf
|
973 |
-
else: logits[k, : self.tokenizer.eot] = -np.inf
|
974 |
-
|
975 |
-
timestamps = sampled_tokens[sampled_tokens.ge(self.tokenizer.timestamp_begin)]
|
976 |
-
|
977 |
-
if timestamps.numel() > 0: logits[k, self.tokenizer.timestamp_begin : timestamps[-1] if last_was_timestamp and not penultimate_was_timestamp else (timestamps[-1] + 1)] = -np.inf
|
978 |
-
|
979 |
-
if tokens.shape[1] == self.sample_begin:
|
980 |
-
logits[:, : self.tokenizer.timestamp_begin] = -np.inf
|
981 |
-
|
982 |
-
if self.max_initial_timestamp_index is not None:
|
983 |
-
last_allowed = (self.tokenizer.timestamp_begin + self.max_initial_timestamp_index)
|
984 |
-
logits[:, last_allowed + 1 :] = -np.inf
|
985 |
-
|
986 |
-
logprobs = F.log_softmax(logits.float(), dim=-1)
|
987 |
-
for k in range(tokens.shape[0]):
|
988 |
-
if logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1) > logprobs[k, : self.tokenizer.timestamp_begin].max(): logits[k, : self.tokenizer.timestamp_begin] = -np.inf
|
989 |
-
|
990 |
-
class DecodingTask:
|
991 |
-
def __init__(self, model, options):
|
992 |
-
self.model = model
|
993 |
-
|
994 |
-
language = options.language or "en"
|
995 |
-
tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages, language=language, task=options.task)
|
996 |
-
|
997 |
-
self.tokenizer = tokenizer
|
998 |
-
self.options = self._verify_options(options)
|
999 |
-
|
1000 |
-
self.n_group = options.beam_size or options.best_of or 1
|
1001 |
-
self.n_ctx = model.dims.n_text_ctx
|
1002 |
-
self.sample_len = options.sample_len or model.dims.n_text_ctx // 2
|
1003 |
-
|
1004 |
-
self.sot_sequence = tokenizer.sot_sequence
|
1005 |
-
if self.options.without_timestamps: self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
|
1006 |
-
|
1007 |
-
self.initial_tokens = self._get_initial_tokens()
|
1008 |
-
self.sample_begin = len(self.initial_tokens)
|
1009 |
-
self.sot_index = self.initial_tokens.index(tokenizer.sot)
|
1010 |
-
self.inference = PyTorchInference(model, len(self.initial_tokens))
|
1011 |
-
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
|
1012 |
-
self.decoder = BeamSearchDecoder(options.beam_size, tokenizer.eot, self.inference, options.patience) if options.beam_size is not None else GreedyDecoder(options.temperature, tokenizer.eot)
|
1013 |
-
|
1014 |
-
self.logit_filters = []
|
1015 |
-
|
1016 |
-
if self.options.suppress_blank: self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
|
1017 |
-
if self.options.suppress_tokens: self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
|
1018 |
-
|
1019 |
-
if not options.without_timestamps:
|
1020 |
-
max_initial_timestamp_index = None
|
1021 |
-
if options.max_initial_timestamp: max_initial_timestamp_index = round(self.options.max_initial_timestamp / (CHUNK_LENGTH / model.dims.n_audio_ctx))
|
1022 |
-
self.logit_filters.append(ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index))
|
1023 |
-
|
1024 |
-
def _verify_options(self, options):
|
1025 |
-
if options.beam_size is not None and options.best_of is not None: raise ValueError
|
1026 |
-
if options.temperature == 0 and options.best_of is not None: raise ValueError
|
1027 |
-
if options.patience is not None and options.beam_size is None: raise ValueError
|
1028 |
-
if options.length_penalty is not None and not (0 <= options.length_penalty <= 1): raise ValueError
|
1029 |
-
|
1030 |
-
return options
|
1031 |
-
|
1032 |
-
def _get_initial_tokens(self):
|
1033 |
-
tokens = list(self.sot_sequence)
|
1034 |
-
|
1035 |
-
if prefix := self.options.prefix:
|
1036 |
-
prefix_tokens = (self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix)
|
1037 |
-
if self.sample_len is not None: prefix_tokens = prefix_tokens[-(self.n_ctx // 2 - self.sample_len):]
|
1038 |
-
tokens = tokens + prefix_tokens
|
1039 |
-
|
1040 |
-
if prompt := self.options.prompt: tokens = ([self.tokenizer.sot_prev] + (self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt)[-(self.n_ctx // 2 - 1) :] + tokens)
|
1041 |
-
|
1042 |
-
return tuple(tokens)
|
1043 |
-
|
1044 |
-
def _get_suppress_tokens(self):
|
1045 |
-
suppress_tokens = self.options.suppress_tokens
|
1046 |
-
if isinstance(suppress_tokens, str): suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
|
1047 |
-
|
1048 |
-
if -1 in suppress_tokens:
|
1049 |
-
suppress_tokens = [t for t in suppress_tokens if t >= 0]
|
1050 |
-
suppress_tokens.extend(self.tokenizer.non_speech_tokens)
|
1051 |
-
elif suppress_tokens is None or len(suppress_tokens) == 0: suppress_tokens = []
|
1052 |
-
else: assert isinstance(suppress_tokens, list)
|
1053 |
-
|
1054 |
-
suppress_tokens.extend([self.tokenizer.transcribe, self.tokenizer.translate, self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm])
|
1055 |
-
|
1056 |
-
if self.tokenizer.no_speech is not None: suppress_tokens.append(self.tokenizer.no_speech)
|
1057 |
-
return tuple(sorted(set(suppress_tokens)))
|
1058 |
-
|
1059 |
-
def _get_audio_features(self, mel):
|
1060 |
-
if self.options.fp16: mel = mel.half()
|
1061 |
-
|
1062 |
-
audio_features = mel if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state) else self.model.encoder(mel)
|
1063 |
-
if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32): raise TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
|
1064 |
-
|
1065 |
-
return audio_features
|
1066 |
-
|
1067 |
-
def _detect_language(self, audio_features, tokens):
|
1068 |
-
languages = [self.options.language] * audio_features.shape[0]
|
1069 |
-
lang_probs = None
|
1070 |
-
|
1071 |
-
if self.options.language is None or self.options.task == "lang_id":
|
1072 |
-
lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
|
1073 |
-
languages = [max(probs, key=probs.get) for probs in lang_probs]
|
1074 |
-
|
1075 |
-
if self.options.language is None: tokens[:, self.sot_index + 1] = lang_tokens
|
1076 |
-
|
1077 |
-
return languages, lang_probs
|
1078 |
-
|
1079 |
-
def _main_loop(self, audio_features, tokens):
|
1080 |
-
n_batch = tokens.shape[0]
|
1081 |
-
sum_logprobs = torch.zeros(n_batch, device=audio_features.device)
|
1082 |
-
no_speech_probs = [np.nan] * n_batch
|
1083 |
-
|
1084 |
-
try:
|
1085 |
-
for i in range(self.sample_len):
|
1086 |
-
logits = self.inference.logits(tokens, audio_features)
|
1087 |
-
|
1088 |
-
if (i == 0 and self.tokenizer.no_speech is not None):
|
1089 |
-
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
|
1090 |
-
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
1091 |
-
|
1092 |
-
logits = logits[:, -1]
|
1093 |
-
for logit_filter in self.logit_filters:
|
1094 |
-
logit_filter.apply(logits, tokens)
|
1095 |
-
|
1096 |
-
tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
|
1097 |
-
if completed or tokens.shape[-1] > self.n_ctx: break
|
1098 |
-
finally:
|
1099 |
-
self.inference.cleanup_caching()
|
1100 |
-
|
1101 |
-
return tokens, sum_logprobs, no_speech_probs
|
1102 |
-
|
1103 |
-
@torch.no_grad()
|
1104 |
-
def run(self, mel):
|
1105 |
-
self.decoder.reset()
|
1106 |
-
tokenizer = self.tokenizer
|
1107 |
-
n_audio = mel.shape[0]
|
1108 |
-
|
1109 |
-
audio_features = self._get_audio_features(mel)
|
1110 |
-
tokens = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
|
1111 |
-
|
1112 |
-
languages, language_probs = self._detect_language(audio_features, tokens)
|
1113 |
-
if self.options.task == "lang_id": return [DecodingResult(audio_features=features, language=language, language_probs=probs) for features, language, probs in zip(audio_features, languages, language_probs)]
|
1114 |
-
|
1115 |
-
tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
|
1116 |
-
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
|
1117 |
-
|
1118 |
-
audio_features = audio_features[:: self.n_group]
|
1119 |
-
no_speech_probs = no_speech_probs[:: self.n_group]
|
1120 |
-
|
1121 |
-
assert audio_features.shape[0] == len(no_speech_probs) == n_audio
|
1122 |
-
|
1123 |
-
tokens = tokens.reshape(n_audio, self.n_group, -1)
|
1124 |
-
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
|
1125 |
-
|
1126 |
-
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
|
1127 |
-
tokens = [[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens]
|
1128 |
-
|
1129 |
-
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
|
1130 |
-
tokens = [t[i].tolist() for i, t in zip(selected, tokens)]
|
1131 |
-
|
1132 |
-
fields = ([tokenizer.decode(t).strip() for t in tokens], languages, tokens, audio_features, [lp / (len(t) + 1) for t, lp in zip(tokens, [lp[i] for i, lp in zip(selected, sum_logprobs)])], no_speech_probs)
|
1133 |
-
if len(set(map(len, fields))) != 1: raise RuntimeError
|
1134 |
-
|
1135 |
-
return [DecodingResult(audio_features=features, language=language, tokens=tokens, text=text, avg_logprob=avg_logprob, no_speech_prob=no_speech_prob, temperature=self.options.temperature, compression_ratio=compression_ratio(text)) for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)]
|
1136 |
-
|
1137 |
-
class DecodingResult:
|
1138 |
-
def __init__(self, audio_features, language, language_probs = None, tokens = None, text = "", avg_logprob = np.nan, no_speech_prob = np.nan, temperature = np.nan, compression_ratio = np.nan):
|
1139 |
-
self.audio_features = audio_features
|
1140 |
-
self.language = language
|
1141 |
-
self.language_probs = language_probs if language_probs is not None else {}
|
1142 |
-
self.tokens = tokens if tokens is not None else []
|
1143 |
-
self.text = text
|
1144 |
-
self.avg_logprob = avg_logprob
|
1145 |
-
self.no_speech_prob = no_speech_prob
|
1146 |
-
self.temperature = temperature
|
1147 |
-
self.compression_ratio = compression_ratio
|
1148 |
-
|
1149 |
-
class Tokenizer:
|
1150 |
-
def __init__(self, encoding_name, num_languages = 2, language = None, task = None, sot_sequence = ()):
|
1151 |
-
self.encoding = get_encoding(name=encoding_name, num_languages=num_languages)
|
1152 |
-
self.num_languages = num_languages
|
1153 |
-
self.language = language
|
1154 |
-
self.task = task
|
1155 |
-
self.sot_sequence = sot_sequence
|
1156 |
-
self.special_tokens = {}
|
1157 |
-
|
1158 |
-
for special in self.encoding.special_tokens_set:
|
1159 |
-
special_token = self.encoding.encode_single_token(special)
|
1160 |
-
self.special_tokens[special] = special_token
|
1161 |
-
|
1162 |
-
sot = self.special_tokens["<|startoftranscript|>"]
|
1163 |
-
langs = tuple(LANGUAGES.keys())[: self.num_languages]
|
1164 |
-
sot_sequence = [sot]
|
1165 |
-
|
1166 |
-
if self.language is not None: sot_sequence.append(sot + 1 + langs.index(self.language))
|
1167 |
-
if self.task is not None: sot_sequence.append(self.special_tokens["<|transcribe|>"] if self.task == "transcribe" else self.special_tokens["<|translate|>"])
|
1168 |
-
|
1169 |
-
self.sot_sequence = tuple(sot_sequence)
|
1170 |
-
|
1171 |
-
def encode(self, text, **kwargs):
|
1172 |
-
return self.encoding.encode(text, **kwargs)
|
1173 |
-
|
1174 |
-
def decode(self, token_ids, **kwargs):
|
1175 |
-
return self.encoding.decode([t for t in token_ids if t < self.timestamp_begin], **kwargs)
|
1176 |
-
|
1177 |
-
def decode_with_timestamps(self, token_ids, **kwargs):
|
1178 |
-
return self.encoding.decode(token_ids, **kwargs)
|
1179 |
-
|
1180 |
-
@cached_property
|
1181 |
-
def eot(self):
|
1182 |
-
return self.encoding.eot_token
|
1183 |
-
|
1184 |
-
@cached_property
|
1185 |
-
def transcribe(self):
|
1186 |
-
return self.special_tokens["<|transcribe|>"]
|
1187 |
-
|
1188 |
-
@cached_property
|
1189 |
-
def translate(self):
|
1190 |
-
return self.special_tokens["<|translate|>"]
|
1191 |
-
|
1192 |
-
@cached_property
|
1193 |
-
def sot(self):
|
1194 |
-
return self.special_tokens["<|startoftranscript|>"]
|
1195 |
-
|
1196 |
-
@cached_property
|
1197 |
-
def sot_lm(self):
|
1198 |
-
return self.special_tokens["<|startoflm|>"]
|
1199 |
-
|
1200 |
-
@cached_property
|
1201 |
-
def sot_prev(self):
|
1202 |
-
return self.special_tokens["<|startofprev|>"]
|
1203 |
-
|
1204 |
-
@cached_property
|
1205 |
-
def no_speech(self):
|
1206 |
-
return self.special_tokens["<|nospeech|>"]
|
1207 |
-
|
1208 |
-
@cached_property
|
1209 |
-
def no_timestamps(self):
|
1210 |
-
return self.special_tokens["<|notimestamps|>"]
|
1211 |
-
|
1212 |
-
@cached_property
|
1213 |
-
def timestamp_begin(self):
|
1214 |
-
return self.special_tokens["<|0.00|>"]
|
1215 |
-
|
1216 |
-
@cached_property
|
1217 |
-
def language_token(self):
|
1218 |
-
if self.language is None: raise ValueError
|
1219 |
-
return self.to_language_token(self.language)
|
1220 |
-
|
1221 |
-
def to_language_token(self, language):
|
1222 |
-
if token := self.special_tokens.get(f"<|{language}|>", None): return token
|
1223 |
-
raise KeyError
|
1224 |
-
|
1225 |
-
@cached_property
|
1226 |
-
def all_language_tokens(self):
|
1227 |
-
result = []
|
1228 |
-
for token, token_id in self.special_tokens.items():
|
1229 |
-
if token.strip("<|>") in LANGUAGES: result.append(token_id)
|
1230 |
-
|
1231 |
-
return tuple(result)[: self.num_languages]
|
1232 |
-
|
1233 |
-
@cached_property
|
1234 |
-
def all_language_codes(self):
|
1235 |
-
return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
|
1236 |
-
|
1237 |
-
@cached_property
|
1238 |
-
def sot_sequence_including_notimestamps(self):
|
1239 |
-
return tuple(list(self.sot_sequence) + [self.no_timestamps])
|
1240 |
-
|
1241 |
-
@cached_property
|
1242 |
-
def non_speech_tokens(self):
|
1243 |
-
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
|
1244 |
-
symbols += ("<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split())
|
1245 |
-
|
1246 |
-
miscellaneous = set("♩♪♫♬♭♮♯")
|
1247 |
-
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
|
1248 |
-
|
1249 |
-
result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
|
1250 |
-
for symbol in symbols + list(miscellaneous):
|
1251 |
-
for tokens in [self.encoding.encode(symbol), self.encoding.encode(" " + symbol)]:
|
1252 |
-
if len(tokens) == 1 or symbol in miscellaneous: result.add(tokens[0])
|
1253 |
-
|
1254 |
-
return tuple(sorted(result))
|
1255 |
-
|
1256 |
-
def split_to_word_tokens(self, tokens):
|
1257 |
-
if self.language in {"zh", "ja", "th", "lo", "my", "yue"}: return self.split_tokens_on_unicode(tokens)
|
1258 |
-
return self.split_tokens_on_spaces(tokens)
|
1259 |
-
|
1260 |
-
def split_tokens_on_unicode(self, tokens):
|
1261 |
-
replacement_char = "\ufffd"
|
1262 |
-
|
1263 |
-
words, word_tokens, current_tokens = [], [], []
|
1264 |
-
unicode_offset = 0
|
1265 |
-
|
1266 |
-
for token in tokens:
|
1267 |
-
current_tokens.append(token)
|
1268 |
-
decoded = self.decode_with_timestamps(current_tokens)
|
1269 |
-
|
1270 |
-
if (replacement_char not in decoded or self.decode_with_timestamps(tokens)[unicode_offset + decoded.index(replacement_char)] == replacement_char):
|
1271 |
-
words.append(decoded)
|
1272 |
-
word_tokens.append(current_tokens)
|
1273 |
-
current_tokens = []
|
1274 |
-
unicode_offset += len(decoded)
|
1275 |
-
|
1276 |
-
return words, word_tokens
|
1277 |
-
|
1278 |
-
def split_tokens_on_spaces(self, tokens):
|
1279 |
-
subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
|
1280 |
-
words, word_tokens = [], []
|
1281 |
-
|
1282 |
-
for subword, subword_tokens in zip(subwords, subword_tokens_list):
|
1283 |
-
if (subword_tokens[0] >= self.eot) or (subword.startswith(" ")) or (subword.strip() in string.punctuation) or len(words) == 0:
|
1284 |
-
words.append(subword)
|
1285 |
-
word_tokens.append(subword_tokens)
|
1286 |
-
else:
|
1287 |
-
words[-1] = words[-1] + subword
|
1288 |
-
word_tokens[-1].extend(subword_tokens)
|
1289 |
-
|
1290 |
-
return words, word_tokens
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
main/library/utils.py
DELETED
@@ -1,240 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import re
|
3 |
-
import sys
|
4 |
-
import codecs
|
5 |
-
import librosa
|
6 |
-
import logging
|
7 |
-
|
8 |
-
import numpy as np
|
9 |
-
import soundfile as sf
|
10 |
-
|
11 |
-
from pydub import AudioSegment
|
12 |
-
|
13 |
-
sys.path.append(os.getcwd())
|
14 |
-
|
15 |
-
from main.tools import huggingface
|
16 |
-
from main.configs.config import Config
|
17 |
-
|
18 |
-
for l in ["httpx", "httpcore"]:
|
19 |
-
logging.getLogger(l).setLevel(logging.ERROR)
|
20 |
-
|
21 |
-
translations = Config().translations
|
22 |
-
|
23 |
-
|
24 |
-
def check_predictors(method, f0_onnx=False):
|
25 |
-
if f0_onnx and method not in ["harvest", "dio"]: method += "-onnx"
|
26 |
-
|
27 |
-
def download(predictors):
|
28 |
-
if not os.path.exists(os.path.join("assets", "models", "predictors", predictors)): huggingface.HF_download_file(codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/cerqvpgbef/", "rot13") + predictors, os.path.join("assets", "models", "predictors", predictors))
|
29 |
-
|
30 |
-
model_dict = {**dict.fromkeys(["rmvpe", "rmvpe-legacy"], "rmvpe.pt"), **dict.fromkeys(["rmvpe-onnx", "rmvpe-legacy-onnx"], "rmvpe.onnx"), **dict.fromkeys(["fcpe"], "fcpe.pt"), **dict.fromkeys(["fcpe-legacy"], "fcpe_legacy.pt"), **dict.fromkeys(["fcpe-onnx"], "fcpe.onnx"), **dict.fromkeys(["fcpe-legacy-onnx"], "fcpe_legacy.onnx"), **dict.fromkeys(["crepe-full", "mangio-crepe-full"], "crepe_full.pth"), **dict.fromkeys(["crepe-full-onnx", "mangio-crepe-full-onnx"], "crepe_full.onnx"), **dict.fromkeys(["crepe-large", "mangio-crepe-large"], "crepe_large.pth"), **dict.fromkeys(["crepe-large-onnx", "mangio-crepe-large-onnx"], "crepe_large.onnx"), **dict.fromkeys(["crepe-medium", "mangio-crepe-medium"], "crepe_medium.pth"), **dict.fromkeys(["crepe-medium-onnx", "mangio-crepe-medium-onnx"], "crepe_medium.onnx"), **dict.fromkeys(["crepe-small", "mangio-crepe-small"], "crepe_small.pth"), **dict.fromkeys(["crepe-small-onnx", "mangio-crepe-small-onnx"], "crepe_small.onnx"), **dict.fromkeys(["crepe-tiny", "mangio-crepe-tiny"], "crepe_tiny.pth"), **dict.fromkeys(["crepe-tiny-onnx", "mangio-crepe-tiny-onnx"], "crepe_tiny.onnx"), **dict.fromkeys(["harvest", "dio"], "world.pth")}
|
31 |
-
|
32 |
-
if "hybrid" in method:
|
33 |
-
methods_str = re.search(r"hybrid\[(.+)\]", method)
|
34 |
-
if methods_str: methods = [method.strip() for method in methods_str.group(1).split("+")]
|
35 |
-
|
36 |
-
for method in methods:
|
37 |
-
if method in model_dict: download(model_dict[method])
|
38 |
-
elif method in model_dict: download(model_dict[method])
|
39 |
-
|
40 |
-
def check_embedders(hubert, embedders_mode="fairseq"):
|
41 |
-
huggingface_url = codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/rzorqqref/", "rot13")
|
42 |
-
|
43 |
-
if hubert in ["contentvec_base", "hubert_base", "japanese_hubert_base", "korean_hubert_base", "chinese_hubert_base", "portuguese_hubert_base"]:
|
44 |
-
if embedders_mode == "fairseq": hubert += ".pt"
|
45 |
-
elif embedders_mode == "onnx": hubert += ".onnx"
|
46 |
-
|
47 |
-
model_path = os.path.join("assets", "models", "embedders", hubert)
|
48 |
-
|
49 |
-
if embedders_mode == "fairseq":
|
50 |
-
if not os.path.exists(model_path): huggingface.HF_download_file("".join([huggingface_url, "fairseq/", hubert]), model_path)
|
51 |
-
elif embedders_mode == "onnx":
|
52 |
-
if not os.path.exists(model_path): huggingface.HF_download_file("".join([huggingface_url, "onnx/", hubert]), model_path)
|
53 |
-
elif embedders_mode == "transformers":
|
54 |
-
bin_file = os.path.join(model_path, "model.safetensors")
|
55 |
-
config_file = os.path.join(model_path, "config.json")
|
56 |
-
|
57 |
-
os.makedirs(model_path, exist_ok=True)
|
58 |
-
|
59 |
-
if not os.path.exists(bin_file): huggingface.HF_download_file("".join([huggingface_url, "transformers/", hubert, "/model.safetensors"]), bin_file)
|
60 |
-
if not os.path.exists(config_file): huggingface.HF_download_file("".join([huggingface_url, "transformers/", hubert, "/config.json"]), config_file)
|
61 |
-
else: raise ValueError(translations["option_not_valid"])
|
62 |
-
|
63 |
-
def check_spk_diarization(model_size):
|
64 |
-
whisper_model = os.path.join("assets", "models", "speaker_diarization", "models", f"{model_size}.pt")
|
65 |
-
if not os.path.exists(whisper_model): huggingface.HF_download_file("".join([codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/fcrnxre_qvnevmngvba/", "rot13"), model_size, ".pt"]), whisper_model)
|
66 |
-
|
67 |
-
speechbrain_path = os.path.join("assets", "models", "speaker_diarization", "models", "speechbrain")
|
68 |
-
if not os.path.exists(speechbrain_path): os.makedirs(speechbrain_path, exist_ok=True)
|
69 |
-
|
70 |
-
for f in ["classifier.ckpt", "config.json", "embedding_model.ckpt", "hyperparams.yaml", "mean_var_norm_emb.ckpt"]:
|
71 |
-
speechbrain_model = os.path.join(speechbrain_path, f)
|
72 |
-
if not os.path.exists(speechbrain_model): huggingface.HF_download_file(codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/fcrnxre_qvnevmngvba/fcrrpuoenva/", "rot13") + f, speechbrain_model)
|
73 |
-
|
74 |
-
def check_audioldm2(model):
|
75 |
-
for f in ["feature_extractor", "language_model", "projection_model", "scheduler", "text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2", "unet", "vae", "vocoder"]:
|
76 |
-
folder_path = os.path.join("assets", "models", "audioldm2", model, f)
|
77 |
-
if not os.path.exists(folder_path): os.makedirs(folder_path, exist_ok=True)
|
78 |
-
|
79 |
-
for f in ["feature_extractor/preprocessor_config.json","language_model/config.json","language_model/model.safetensors","model_index.json","projection_model/config.json","projection_model/diffusion_pytorch_model.safetensors","scheduler/scheduler_config.json","text_encoder/config.json","text_encoder/model.safetensors","text_encoder_2/config.json","text_encoder_2/model.safetensors","tokenizer/merges.txt","tokenizer/special_tokens_map.json","tokenizer/tokenizer.json","tokenizer/tokenizer_config.json","tokenizer/vocab.json","tokenizer_2/special_tokens_map.json","tokenizer_2/spiece.model","tokenizer_2/tokenizer.json","tokenizer_2/tokenizer_config.json","unet/config.json","unet/diffusion_pytorch_model.safetensors","vae/config.json","vae/diffusion_pytorch_model.safetensors","vocoder/config.json","vocoder/model.safetensors"]:
|
80 |
-
model_path = os.path.join("assets", "models", "audioldm2", model, f)
|
81 |
-
if not os.path.exists(model_path): huggingface.HF_download_file("".join([codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/nhqvbyqz/", "rot13"), model, "/", f]), model_path)
|
82 |
-
|
83 |
-
def load_audio(logger, file, sample_rate=16000, formant_shifting=False, formant_qfrency=0.8, formant_timbre=0.8):
|
84 |
-
try:
|
85 |
-
file = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
|
86 |
-
if not os.path.isfile(file): raise FileNotFoundError(translations["not_found"].format(name=file))
|
87 |
-
|
88 |
-
try:
|
89 |
-
logger.debug(translations['read_sf'])
|
90 |
-
audio, sr = sf.read(file, dtype=np.float32)
|
91 |
-
except:
|
92 |
-
logger.debug(translations['read_librosa'])
|
93 |
-
audio, sr = librosa.load(file, sr=None)
|
94 |
-
|
95 |
-
if len(audio.shape) > 1: audio = librosa.to_mono(audio.T)
|
96 |
-
if sr != sample_rate: audio = librosa.resample(audio, orig_sr=sr, target_sr=sample_rate, res_type="soxr_vhq")
|
97 |
-
|
98 |
-
if formant_shifting:
|
99 |
-
from main.library.algorithm.stftpitchshift import StftPitchShift
|
100 |
-
|
101 |
-
pitchshifter = StftPitchShift(1024, 32, sample_rate)
|
102 |
-
audio = pitchshifter.shiftpitch(audio, factors=1, quefrency=formant_qfrency * 1e-3, distortion=formant_timbre)
|
103 |
-
except Exception as e:
|
104 |
-
raise RuntimeError(f"{translations['errors_loading_audio']}: {e}")
|
105 |
-
|
106 |
-
return audio.flatten()
|
107 |
-
|
108 |
-
def pydub_convert(audio):
|
109 |
-
samples = np.frombuffer(audio.raw_data, dtype=np.int16)
|
110 |
-
if samples.dtype != np.int16: samples = (samples * 32767).astype(np.int16)
|
111 |
-
return AudioSegment(samples.tobytes(), frame_rate=audio.frame_rate, sample_width=samples.dtype.itemsize, channels=audio.channels)
|
112 |
-
|
113 |
-
def pydub_load(input_path):
|
114 |
-
try:
|
115 |
-
if input_path.endswith(".wav"): audio = AudioSegment.from_wav(input_path)
|
116 |
-
elif input_path.endswith(".mp3"): audio = AudioSegment.from_mp3(input_path)
|
117 |
-
elif input_path.endswith(".ogg"): audio = AudioSegment.from_ogg(input_path)
|
118 |
-
else: audio = AudioSegment.from_file(input_path)
|
119 |
-
except:
|
120 |
-
audio = AudioSegment.from_file(input_path)
|
121 |
-
|
122 |
-
return audio
|
123 |
-
|
124 |
-
def load_embedders_model(embedder_model, embedders_mode="fairseq", providers=None):
|
125 |
-
if embedders_mode == "fairseq": embedder_model += ".pt"
|
126 |
-
elif embedders_mode == "onnx": embedder_model += ".onnx"
|
127 |
-
|
128 |
-
embedder_model_path = os.path.join("assets", "models", "embedders", embedder_model)
|
129 |
-
if not os.path.exists(embedder_model_path): raise FileNotFoundError(f"{translations['not_found'].format(name=translations['model'])}: {embedder_model}")
|
130 |
-
|
131 |
-
try:
|
132 |
-
if embedders_mode == "fairseq":
|
133 |
-
from main.library.architectures import fairseq
|
134 |
-
|
135 |
-
models, saved_cfg, _ = fairseq.load_model(embedder_model_path)
|
136 |
-
embed_suffix = ".pt"
|
137 |
-
hubert_model = models[0]
|
138 |
-
elif embedders_mode == "onnx":
|
139 |
-
import onnxruntime
|
140 |
-
|
141 |
-
sess_options = onnxruntime.SessionOptions()
|
142 |
-
sess_options.log_severity_level = 3
|
143 |
-
embed_suffix, saved_cfg = ".onnx", None
|
144 |
-
hubert_model = onnxruntime.InferenceSession(embedder_model_path, sess_options=sess_options, providers=providers)
|
145 |
-
elif embedders_mode == "transformers":
|
146 |
-
from torch import nn
|
147 |
-
from transformers import HubertModel
|
148 |
-
|
149 |
-
class HubertModelWithFinalProj(HubertModel):
|
150 |
-
def __init__(self, config):
|
151 |
-
super().__init__(config)
|
152 |
-
self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
153 |
-
|
154 |
-
embed_suffix, saved_cfg = ".safetensors", None
|
155 |
-
hubert_model = HubertModelWithFinalProj.from_pretrained(embedder_model_path)
|
156 |
-
else: raise ValueError(translations["option_not_valid"])
|
157 |
-
except Exception as e:
|
158 |
-
raise RuntimeError(translations["read_model_error"].format(e=e))
|
159 |
-
|
160 |
-
return hubert_model, saved_cfg, embed_suffix
|
161 |
-
|
162 |
-
def cut(audio, sr, db_thresh=-60, min_interval=250):
|
163 |
-
from main.inference.preprocess import Slicer, get_rms
|
164 |
-
|
165 |
-
class Slicer2(Slicer):
|
166 |
-
def slice2(self, waveform):
|
167 |
-
samples = waveform.mean(axis=0) if len(waveform.shape) > 1 else waveform
|
168 |
-
|
169 |
-
if samples.shape[0] <= self.min_length: return [(waveform, 0, samples.shape[0])]
|
170 |
-
rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
|
171 |
-
|
172 |
-
sil_tags = []
|
173 |
-
silence_start, clip_start = None, 0
|
174 |
-
|
175 |
-
for i, rms in enumerate(rms_list):
|
176 |
-
if rms < self.threshold:
|
177 |
-
if silence_start is None: silence_start = i
|
178 |
-
continue
|
179 |
-
|
180 |
-
if silence_start is None: continue
|
181 |
-
|
182 |
-
is_leading_silence = silence_start == 0 and i > self.max_sil_kept
|
183 |
-
need_slice_middle = (i - silence_start >= self.min_interval and i - clip_start >= self.min_length)
|
184 |
-
|
185 |
-
if not is_leading_silence and not need_slice_middle:
|
186 |
-
silence_start = None
|
187 |
-
continue
|
188 |
-
|
189 |
-
if i - silence_start <= self.max_sil_kept:
|
190 |
-
pos = rms_list[silence_start : i + 1].argmin() + silence_start
|
191 |
-
sil_tags.append((0, pos) if silence_start == 0 else (pos, pos))
|
192 |
-
clip_start = pos
|
193 |
-
elif i - silence_start <= self.max_sil_kept * 2:
|
194 |
-
pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
|
195 |
-
pos += i - self.max_sil_kept
|
196 |
-
|
197 |
-
pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)
|
198 |
-
|
199 |
-
if silence_start == 0:
|
200 |
-
sil_tags.append((0, pos_r))
|
201 |
-
clip_start = pos_r
|
202 |
-
else:
|
203 |
-
sil_tags.append((min((rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start), pos), max(pos_r, pos)))
|
204 |
-
clip_start = max(pos_r, pos)
|
205 |
-
else:
|
206 |
-
pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)
|
207 |
-
sil_tags.append((0, pos_r) if silence_start == 0 else ((rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start), pos_r))
|
208 |
-
clip_start = pos_r
|
209 |
-
|
210 |
-
silence_start = None
|
211 |
-
|
212 |
-
total_frames = rms_list.shape[0]
|
213 |
-
if (silence_start is not None and total_frames - silence_start >= self.min_interval): sil_tags.append((rms_list[silence_start : min(total_frames, silence_start + self.max_sil_kept) + 1].argmin() + silence_start, total_frames + 1))
|
214 |
-
|
215 |
-
if not sil_tags: return [(waveform, 0, samples.shape[-1])]
|
216 |
-
else:
|
217 |
-
chunks = []
|
218 |
-
if sil_tags[0][0] > 0: chunks.append((self._apply_slice(waveform, 0, sil_tags[0][0]), 0, sil_tags[0][0] * self.hop_size))
|
219 |
-
|
220 |
-
for i in range(len(sil_tags) - 1):
|
221 |
-
chunks.append((self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]), sil_tags[i][1] * self.hop_size, sil_tags[i + 1][0] * self.hop_size))
|
222 |
-
|
223 |
-
if sil_tags[-1][1] < total_frames: chunks.append((self._apply_slice(waveform, sil_tags[-1][1], total_frames), sil_tags[-1][1] * self.hop_size, samples.shape[-1]))
|
224 |
-
return chunks
|
225 |
-
|
226 |
-
slicer = Slicer2(sr=sr, threshold=db_thresh, min_interval=min_interval)
|
227 |
-
return slicer.slice2(audio)
|
228 |
-
|
229 |
-
def restore(segments, total_len, dtype=np.float32):
|
230 |
-
out = []
|
231 |
-
last_end = 0
|
232 |
-
|
233 |
-
for start, end, processed_seg in segments:
|
234 |
-
if start > last_end: out.append(np.zeros(start - last_end, dtype=dtype))
|
235 |
-
|
236 |
-
out.append(processed_seg)
|
237 |
-
last_end = end
|
238 |
-
|
239 |
-
if last_end < total_len: out.append(np.zeros(total_len - last_end, dtype=dtype))
|
240 |
-
return np.concatenate(out, axis=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|