sam2ai committed on
Commit
6de3e11
1 Parent(s): f6a9722

Synced repo using 'sync_with_huggingface' Github Action

Dockerfile ADDED
@@ -0,0 +1,69 @@
+ FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04
+
+ # Use Python 3.11 for better performance
+ # Update the package lists and install necessary dependencies
+ RUN apt-get update && apt-get install -y \
+     software-properties-common \
+     && add-apt-repository -y ppa:deadsnakes/ppa \
+     && apt-get update \
+     && apt-get install -y python3.11 python3.11-dev
+
+ # Set Python 3.11 as the default version (for python3)
+ RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1
+
+ # Download get-pip.py script
+ RUN apt install curl -y
+ RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
+
+ # Install pip for Python 3.11
+ RUN python3 get-pip.py
+
+ # Verify Python and pip versions
+ RUN python3 --version && pip3.11 --version
+
+ # Set pip3.11 as the default pip command
+ RUN update-alternatives --install /usr/bin/pip3 pip3 /usr/local/lib/python3.11/dist-packages/pip 1
+
+ ENV PYTHONUNBUFFERED=1
+
+ # Install necessary dependencies
+ # RUN apt-get update && \
+ #     apt-get install -y python3-pip
+
+ # Set the working directory. /app is mounted to the container with -v,
+ # but we want to have the right cwd for the uvicorn command below
+ RUN mkdir /app
+ # WORKDIR /app
+
+ # # Copy the app code and requirements file
+ # COPY . /app
+ # COPY requirements.txt .
+ # WORKDIR $PYSETUP_PATH
+ COPY ./requirements.txt /app
+
+
+ COPY ./utils /app/utils
+ COPY ./static /app/static
+ COPY ./templates /app/templates
+ COPY ./infer_server.py /app/infer_server.py
+ COPY ./download.py /app/download.py
+
+ WORKDIR /app
+
+
+ # Install the app dependencies
+ # RUN pip3 install -r requirements.txt
+
+ RUN --mount=type=cache,target=/root/.cache/pip \
+     pip3 install -r requirements.txt
+
+ # Expose the FastAPI port
+ EXPOSE 5001
+
+ # Start the FastAPI app using the Uvicorn web server
+ # CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "14000", "--limit-concurrency", "1000"]
+ RUN python3 download.py
+
+ CMD ["python3", "infer_server.py", "--host=0.0.0.0", "--port=5001", "--model_path=models/sam2ai/whisper-odia-small-finetune-int8-ct2", "--num_workers=2"]
+
+
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
configs/augmentation.json ADDED
@@ -0,0 +1,43 @@
+ [
+   {
+     "type": "resample",
+     "params": {
+       "new_sample_rates": [8000, 32000, 44100]
+     },
+     "prob": 0.0
+   },
+   {
+     "type": "noise",
+     "params": {
+       "min_snr_dB": 10,
+       "max_snr_dB": 50,
+       "noise_dir": "dataset/noise"
+     },
+     "prob": 0.2
+   },
+   {
+     "type": "speed",
+     "params": {
+       "min_speed_rate": 0.9,
+       "max_speed_rate": 1.1,
+       "num_rates": 3
+     },
+     "prob": 0.5
+   },
+   {
+     "type": "shift",
+     "params": {
+       "min_shift_ms": -5,
+       "max_shift_ms": 5
+     },
+     "prob": 0.0
+   },
+   {
+     "type": "volume",
+     "params": {
+       "min_gain_dBFS": -15,
+       "max_gain_dBFS": 15
+     },
+     "prob": 0.5
+   }
+ ]
dataset/test.mp3 ADDED
Binary file (61.7 kB).
 
download.py ADDED
@@ -0,0 +1,44 @@
+ import argparse
+ import requests
+ import os
+ from tqdm import tqdm
+
+ def download_file(url, path):
+     response = requests.get(url, stream=True)
+     total_size_in_bytes = int(response.headers.get('content-length', 0))
+     block_size = 1024  # 1 Kbyte
+     progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
+
+     with open(path, 'wb') as file:
+         for data in response.iter_content(block_size):
+             progress_bar.update(len(data))
+             file.write(data)
+
+     progress_bar.close()
+
+ def download_model(model_name, destination_folder="models"):
+     # Define the base URL and headers for the Hugging Face API
+     base_url = f"https://huggingface.co/{model_name}/resolve/main"
+     headers = {"User-Agent": "Hugging Face Python"}
+
+     # Send a GET request to the Hugging Face API to get a list of all files
+     response = requests.get(f"https://huggingface.co/api/models/{model_name}", headers=headers)
+     response.raise_for_status()
+
+     # Extract the list of files from the response JSON
+     files_to_download = [file["rfilename"] for file in response.json()["siblings"]]
+
+     # Ensure the directory exists
+     os.makedirs(f"{destination_folder}/{model_name}", exist_ok=True)
+
+     # Download each file
+     for file in files_to_download:
+         print(f"Downloading {file}...")
+         download_file(f"{base_url}/{file}", f"{destination_folder}/{model_name}/{file}")
+
+ if __name__ == "__main__":
+     # parser = argparse.ArgumentParser()
+     # parser.add_argument("model_name", type=str, default="sam2ai/whisper-odia-small-finetune-int8-ct2", help="Name of the model to download.")
+     # args = parser.parse_args()
+
+     download_model("sam2ai/whisper-odia-small-finetune-int8-ct2")
evaluation.py ADDED
@@ -0,0 +1,96 @@
+ import argparse
+ import functools
+ import gc
+ import os
+
+ import evaluate
+ import numpy as np
+ import torch
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor
+
+ from utils.data_utils import DataCollatorSpeechSeq2SeqWithPadding, remove_punctuation, to_simple
+ from utils.reader import CustomDataset
+ from utils.utils import print_arguments, add_arguments
+
+ parser = argparse.ArgumentParser(description=__doc__)
+ add_arg = functools.partial(add_arguments, argparser=parser)
+ add_arg("test_data", type=str, default="dataset/test.json", help="Path to the test set")
+ add_arg("model_path", type=str, default="models/whisper-tiny-finetune", help="Path to the merged model, or a model name on Hugging Face")
+ add_arg("batch_size", type=int, default=16, help="Evaluation batch size")
+ add_arg("num_workers", type=int, default=8, help="Number of data-loading workers")
+ add_arg("language", type=str, default="Chinese", help="Language, full name or abbreviation; if None, evaluate multilingually")
+ add_arg("remove_pun", type=bool, default=True, help="Whether to remove punctuation")
+ add_arg("to_simple", type=bool, default=True, help="Whether to convert Traditional Chinese to Simplified Chinese")
+ add_arg("timestamps", type=bool, default=False, help="Whether to use timestamp data during evaluation")
+ add_arg("min_audio_len", type=float, default=0.5, help="Minimum audio length in seconds")
+ add_arg("max_audio_len", type=float, default=30, help="Maximum audio length in seconds")
+ add_arg("local_files_only", type=bool, default=True, help="Only load the model locally; do not try to download")
+ add_arg("task", type=str, default="transcribe", choices=['transcribe', 'translate'], help="Model task")
+ add_arg("metric", type=str, default="cer", choices=['cer', 'wer'], help="Evaluation metric")
+ args = parser.parse_args()
+ print_arguments(args)
+
+ # Check that the model path is valid
+ assert 'openai' == os.path.dirname(args.model_path) or os.path.exists(args.model_path), \
+     f"Model {args.model_path} does not exist; check that the model was merged successfully or that the name exists on Hugging Face"
+ # Get the Whisper processor, which contains the feature extractor and tokenizer
+ processor = WhisperProcessor.from_pretrained(args.model_path,
+                                              language=args.language,
+                                              task=args.task,
+                                              no_timestamps=not args.timestamps,
+                                              local_files_only=args.local_files_only)
+ forced_decoder_ids = processor.get_decoder_prompt_ids()
+ # Load the model
+ model = WhisperForConditionalGeneration.from_pretrained(args.model_path,
+                                                         device_map="auto",
+                                                         local_files_only=args.local_files_only)
+ model.eval()
+
+ # Load the test data
+ test_dataset = CustomDataset(data_list_path=args.test_data,
+                              processor=processor,
+                              timestamps=args.timestamps,
+                              min_duration=args.min_audio_len,
+                              max_duration=args.max_audio_len)
+ print(f"Test set size: {len(test_dataset)}")
+
+ # Data collator for padding
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
+ eval_dataloader = DataLoader(test_dataset, batch_size=args.batch_size,
+                              num_workers=args.num_workers, collate_fn=data_collator)
+
+ # Load the evaluation metric
+ metric = evaluate.load(args.metric)
+
+ # Start evaluation
+ for step, batch in enumerate(tqdm(eval_dataloader)):
+     with torch.cuda.amp.autocast():
+         with torch.no_grad():
+             generated_tokens = (
+                 model.generate(
+                     input_features=batch["input_features"].cuda(),
+                     decoder_input_ids=batch["labels"][:, :4].cuda(),
+                     forced_decoder_ids=forced_decoder_ids,
+                     max_new_tokens=255).cpu().numpy())
+             labels = batch["labels"].cpu().numpy()
+             labels = np.where(labels != -100, labels, processor.tokenizer.pad_token_id)
+             # Convert predicted and reference tokens to text
+             decoded_preds = processor.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+             decoded_labels = processor.tokenizer.batch_decode(labels, skip_special_tokens=True)
+             # Remove punctuation
+             if args.remove_pun:
+                 decoded_preds = remove_punctuation(decoded_preds)
+                 decoded_labels = remove_punctuation(decoded_labels)
+             # Convert Traditional Chinese to Simplified Chinese
+             if args.to_simple:
+                 decoded_preds = to_simple(decoded_preds)
+                 decoded_labels = to_simple(decoded_labels)
+             metric.add_batch(predictions=decoded_preds, references=decoded_labels)
+     # Free the memory used by this step
+     del generated_tokens, labels, batch
+     gc.collect()
+ # Compute the final result
+ m = metric.compute()
+ print(f"Evaluation result: {args.metric}={round(m, 5)}")
finetune.py ADDED
@@ -0,0 +1,160 @@
+ import argparse
+ import functools
+ import os
+ import platform
+
+ import torch
+ from peft import LoraConfig, get_peft_model, AdaLoraConfig, PeftModel, prepare_model_for_kbit_training
+ from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, WhisperForConditionalGeneration, WhisperProcessor
+
+ from utils.callback import SavePeftModelCallback
+ from utils.data_utils import DataCollatorSpeechSeq2SeqWithPadding
+ from utils.model_utils import load_from_checkpoint
+ from utils.reader import CustomDataset
+ from utils.utils import print_arguments, make_inputs_require_grad, add_arguments
+
+ parser = argparse.ArgumentParser(description=__doc__)
+ add_arg = functools.partial(add_arguments, argparser=parser)
+ add_arg("train_data", type=str, default="dataset/train.json", help="")
+ add_arg("test_data", type=str, default="dataset/test.json", help="")
+ add_arg("base_model", type=str, default="openai/whisper-tiny", help="Whisper")
+ add_arg("output_dir", type=str, default="output/", help="")
+ add_arg("warmup_steps", type=int, default=50, help="")
+ add_arg("logging_steps", type=int, default=100, help="")
+ add_arg("eval_steps", type=int, default=1000, help="")
+ add_arg("save_steps", type=int, default=1000, help="")
+ add_arg("num_workers", type=int, default=8, help="")
+ add_arg("learning_rate", type=float, default=1e-3, help="")
+ add_arg("min_audio_len", type=float, default=0.5, help="")
+ add_arg("max_audio_len", type=float, default=30, help="")
+ add_arg("use_adalora", type=bool, default=True, help="AdaLora/Lora")
+ add_arg("fp16", type=bool, default=True, help="fp16")
+ add_arg("use_8bit", type=bool, default=False, help="8 bit")
+ add_arg("timestamps", type=bool, default=False, help="")
+ add_arg("local_files_only", type=bool, default=False, help="")
+ add_arg("num_train_epochs", type=int, default=3, help="")
+ add_arg("language", type=str, default="bn", help="")
+ add_arg("task", type=str, default="transcribe", choices=['transcribe', 'translate'], help="Model task")
+ add_arg("augment_config_path", type=str, default=None, help="")
+ add_arg("resume_from_checkpoint", type=str, default=None, help="")
+ add_arg("per_device_train_batch_size", type=int, default=8, help="batch size")
+ add_arg("per_device_eval_batch_size", type=int, default=8, help="batch size")
+ add_arg("gradient_accumulation_steps", type=int, default=1, help="")
+
+ args = parser.parse_args()
+ print_arguments(args)
+
+
+ # Whisper tokenizer
+ processor = WhisperProcessor.from_pretrained(args.base_model,
+                                              language=args.language,
+                                              task=args.task,
+                                              no_timestamps=not args.timestamps,
+                                              local_files_only=args.local_files_only)
+
+ # Train and test datasets
+ train_dataset = CustomDataset(data_list_path=args.train_data,
+                               processor=processor,
+                               language=args.language,
+                               timestamps=args.timestamps,
+                               min_duration=args.min_audio_len,
+                               max_duration=args.max_audio_len,
+                               augment_config_path=args.augment_config_path)
+ test_dataset = CustomDataset(data_list_path=args.test_data,
+                              processor=processor,
+                              language=args.language,
+                              timestamps=args.timestamps,
+                              min_duration=args.min_audio_len,
+                              max_duration=args.max_audio_len)
+ print(f"len train - {len(train_dataset)} test len - {len(test_dataset)}")
+
+ # padding collator
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
+
+ # Whisper device map (DDP-aware)
+ device_map = "auto"
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
+ ddp = world_size != 1
+ if ddp:
+     device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
+
+ # Load the base Whisper model
+ model = WhisperForConditionalGeneration.from_pretrained(args.base_model,
+                                                         load_in_8bit=args.use_8bit,
+                                                         device_map=device_map,
+                                                         local_files_only=args.local_files_only)
+ model.config.forced_decoder_ids = None
+ model.config.suppress_tokens = []
+ # Prepare the model for k-bit training
+ model = prepare_model_for_kbit_training(model)
+ # Forward hook so that inputs require grad
+ model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)
+
+ print('Loading LoRA modules...')
+ if args.resume_from_checkpoint:
+     # Resume adapters from a checkpoint
+     print("Loading adapters from checkpoint.")
+     model = PeftModel.from_pretrained(model, args.resume_from_checkpoint, is_trainable=True)
+ else:
+     print(f'adding LoRA modules...')
+     target_modules = ["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"]
+     print(target_modules)
+     if args.use_adalora:
+         config = AdaLoraConfig(init_r=12, target_r=4, beta1=0.85, beta2=0.85, tinit=200, tfinal=1000, deltaT=10,
+                                lora_alpha=32, lora_dropout=0.1, orth_reg_weight=0.5, target_modules=target_modules)
+     else:
+         config = LoraConfig(r=32, lora_alpha=64, target_modules=target_modules, lora_dropout=0.05, bias="none")
+     model = get_peft_model(model, config)
+
+ output_dir = os.path.join(args.output_dir, os.path.basename(args.base_model))
+ # Training arguments
+ training_args = \
+     Seq2SeqTrainingArguments(output_dir=output_dir,  # Directory to save checkpoints
+                              per_device_train_batch_size=args.per_device_train_batch_size,  # Training batch size
+                              per_device_eval_batch_size=args.per_device_eval_batch_size,  # Eval batch size
+                              gradient_accumulation_steps=args.gradient_accumulation_steps,  # Gradient accumulation steps
+                              learning_rate=args.learning_rate,  # Learning rate
+                              warmup_steps=args.warmup_steps,  # Warm-up steps
+                              num_train_epochs=args.num_train_epochs,  # Epochs
+                              save_strategy="steps",
+                              evaluation_strategy="steps",
+                              load_best_model_at_end=True,
+                              fp16=args.fp16,
+                              report_to=["tensorboard"],
+                              save_steps=args.save_steps,
+                              eval_steps=args.eval_steps,
+                              save_total_limit=5,
+                              optim='adamw_torch',
+                              ddp_find_unused_parameters=False if ddp else None,
+                              dataloader_num_workers=args.num_workers,
+                              logging_steps=args.logging_steps,
+                              remove_unused_columns=False,
+                              label_names=["labels"])
+
+ if training_args.local_rank == 0 or training_args.local_rank == -1:
+     print('=' * 90)
+     model.print_trainable_parameters()
+     print('=' * 90)
+
+ # PyTorch 2.0 compile
+ if torch.__version__ >= "2" and platform.system().lower() == 'windows':
+     model = torch.compile(model)
+
+ # Trainer
+ trainer = Seq2SeqTrainer(args=training_args,
+                          model=model,
+                          train_dataset=train_dataset,
+                          eval_dataset=test_dataset,
+                          data_collator=data_collator,
+                          tokenizer=processor.feature_extractor,
+                          callbacks=[SavePeftModelCallback])
+ model.config.use_cache = False
+ trainer._load_from_checkpoint = load_from_checkpoint
+
+ # Start training
+ trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
+
+ # Save the final state
+ trainer.save_state()
+ if training_args.local_rank == 0 or training_args.local_rank == -1:
+     model.save_pretrained(os.path.join(output_dir, "checkpoint-final"))
infer_ct2.py ADDED
@@ -0,0 +1,46 @@
+ import argparse
+ import functools
+ import os
+
+ from faster_whisper import WhisperModel
+
+ from utils.utils import print_arguments, add_arguments
+
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+ parser = argparse.ArgumentParser(description=__doc__)
+ add_arg = functools.partial(add_arguments, argparser=parser)
+ add_arg("audio_path", type=str, default="dataset/test.wav", help="")
+ add_arg("model_path", type=str, default="models/whisper-tiny-finetune-ct2", help="")
+ add_arg("language", type=str, default="zh", help="")
+ add_arg("use_gpu", type=bool, default=True, help="")
+ add_arg("use_int8", type=bool, default=False, help="int8")
+ add_arg("beam_size", type=int, default=10, help="")
+ add_arg("num_workers", type=int, default=1, help="")
+ add_arg("vad_filter", type=bool, default=False, help="")
+ add_arg("local_files_only", type=bool, default=True, help="")
+ args = parser.parse_args()
+ print_arguments(args)
+
+ # Check that the model path exists
+ assert os.path.exists(args.model_path), f"{args.model_path}"
+ # Load the CTranslate2 model
+ if args.use_gpu:
+     if not args.use_int8:
+         model = WhisperModel(args.model_path, device="cuda", compute_type="float16", num_workers=args.num_workers,
+                              local_files_only=args.local_files_only)
+     else:
+         model = WhisperModel(args.model_path, device="cuda", compute_type="int8_float16", num_workers=args.num_workers,
+                              local_files_only=args.local_files_only)
+ else:
+     model = WhisperModel(args.model_path, device="cpu", compute_type="int8", num_workers=args.num_workers,
+                          local_files_only=args.local_files_only)
+ # Warm-up call
+ _, _ = model.transcribe("dataset/test.wav", beam_size=5)
+
+
+ # Transcribe the input audio
+ segments, info = model.transcribe(args.audio_path, beam_size=args.beam_size, language=args.language,
+                                   vad_filter=args.vad_filter)
+ for segment in segments:
+     text = segment.text
+     print(f"[{round(segment.start, 2)} - {round(segment.end, 2)}]:{text}\n")
infer_server.py ADDED
@@ -0,0 +1,143 @@
+ import argparse
+ import asyncio
+ import functools
+ import json
+ import os
+ from io import BytesIO
+
+ import uvicorn
+ from fastapi import FastAPI, BackgroundTasks, File, Body, UploadFile, Request
+ from fastapi.responses import StreamingResponse
+ from faster_whisper import WhisperModel
+ from starlette.staticfiles import StaticFiles
+ from starlette.templating import Jinja2Templates
+ from zhconv import convert
+
+ from utils.data_utils import remove_punctuation
+ from utils.utils import add_arguments, print_arguments
+
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+
+ parser = argparse.ArgumentParser(description=__doc__)
+ add_arg = functools.partial(add_arguments, argparser=parser)
+
+ add_arg("host", type=str, default="0.0.0.0", help="")
+ add_arg("port", type=int, default=5000, help="")
+ add_arg("model_path", type=str, default="models/sam2ai/whisper-odia-small-finetune-int8-ct2", help="")
+ add_arg("use_gpu", type=bool, default=False, help="")
+ add_arg("use_int8", type=bool, default=True, help="")
+ add_arg("beam_size", type=int, default=10, help="")
+ add_arg("num_workers", type=int, default=2, help="")
+ add_arg("vad_filter", type=bool, default=True, help="")
+ add_arg("local_files_only", type=bool, default=True, help="")
+ args = parser.parse_args()
+ print_arguments(args)
+
+ # Check that the model path exists
+ assert os.path.exists(args.model_path), f"{args.model_path}"
+ # Load the CTranslate2 model
+ if args.use_gpu:
+     if not args.use_int8:
+         model = WhisperModel(args.model_path, device="cuda", compute_type="float16",
+                              num_workers=args.num_workers, local_files_only=args.local_files_only)
+     else:
+         model = WhisperModel(args.model_path, device="cuda",
+                              compute_type="int8_float16", num_workers=args.num_workers,
+                              local_files_only=args.local_files_only)
+ else:
+     model = WhisperModel(args.model_path, device="cpu",
+                          compute_type="int8", num_workers=args.num_workers,
+                          local_files_only=args.local_files_only)
+
+ # Warm-up call (disabled)
+ # _, _ = model.transcribe("dataset/test.wav", beam_size=5)
+
+ app = FastAPI(title="")
+ app.mount('/static', StaticFiles(directory='static'), name='static')
+ templates = Jinja2Templates(directory="templates")
+ model_semaphore = None
+
+
+ def release_model_semaphore():
+     model_semaphore.release()
+
+
+ def recognition(file: File, to_simple: int,
+                 remove_pun: int, language: str = "ory",
+                 task: str = "transcribe"
+                 ):
+
+     segments, info = model.transcribe(file, beam_size=10, task=task, language=language, vad_filter=args.vad_filter)
+     for segment in segments:
+         text = segment.text
+         if to_simple == 1:
+             # text = convert(text, '')
+             pass
+         if remove_pun == 1:
+             # text = remove_punctuation(text)
+             pass
+         ret = {"result": text, "start": round(segment.start, 2), "end": round(segment.end, 2)}
+         # Stream each segment as a null-terminated JSON record
+         yield json.dumps(ret).encode() + b"\0"
+
+
+ @app.post("/recognition_stream")
+ async def api_recognition_stream(
+         to_simple: int = Body(1, description="", embed=True),
+         remove_pun: int = Body(0, description="", embed=True),
+         language: str = Body("ory", description="", embed=True),
+         task: str = Body("transcribe", description="", embed=True),
+         audio: UploadFile = File(..., description="")
+ ):
+
+     global model_semaphore
+     if language == "None": language = None
+     if model_semaphore is None:
+         model_semaphore = asyncio.Semaphore(5)
+     await model_semaphore.acquire()
+     contents = await audio.read()
+     data = BytesIO(contents)
+     generator = recognition(
+         file=data, to_simple=to_simple,
+         remove_pun=remove_pun, language=language,
+         task=task
+     )
+     background_tasks = BackgroundTasks()
+     background_tasks.add_task(release_model_semaphore)
+     return StreamingResponse(generator, background=background_tasks)
+
+
+ @app.post("/recognition")
+ async def api_recognition(
+         to_simple: int = Body(1, description="", embed=True),
+         remove_pun: int = Body(0, description="", embed=True),
+         language: str = Body("ory", description="", embed=True),
+         task: str = Body("transcribe", description="", embed=True),
+         audio: UploadFile = File(..., description="")
+ ):
+
+     if language == "None": language = None
+     contents = await audio.read()
+     data = BytesIO(contents)
+     generator = recognition(
+         file=data, to_simple=to_simple,
+         remove_pun=remove_pun, language=language,
+         task=task
+     )
+     results = []
+     for output in generator:
+         output = json.loads(output[:-1].decode("utf-8"))
+         results.append(output)
+     ret = {"results": results, "code": 0}
+     return ret
+
+
+ @app.get("/")
+ async def index(request: Request):
+     return templates.TemplateResponse(
+         "index.html", {"request": request, "id": id}
+     )
+
+
+ if __name__ == '__main__':
+     uvicorn.run(app, host=args.host, port=args.port)
infer_tfs.py ADDED
@@ -0,0 +1,43 @@
+ import argparse
+ import functools
+
+ import librosa
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor
+
+ from utils.utils import print_arguments, add_arguments
+
+ parser = argparse.ArgumentParser(description=__doc__)
+ add_arg = functools.partial(add_arguments, argparser=parser)
+ add_arg("audio_path", type=str, default="dataset/test.wav", help="")
+ add_arg("model_path", type=str, default="models/whisper-tiny-finetune", help="")
+ add_arg("language", type=str, default="Oriya", help="")
+ add_arg("task", type=str, default="transcribe", choices=['transcribe', 'translate'], help="")
+ add_arg("local_files_only", type=bool, default=True, help="")
+ args = parser.parse_args()
+ print_arguments(args)
+
+ # Whisper processor (feature extractor and tokenizer)
+ processor = WhisperProcessor.from_pretrained(args.model_path,
+                                              language=args.language,
+                                              task=args.task,
+                                              local_files_only=args.local_files_only)
+ forced_decoder_ids = processor.get_decoder_prompt_ids(language=args.language, task=args.task)
+
+ # Load the model in half precision
+ model = WhisperForConditionalGeneration.from_pretrained(args.model_path,
+                                                         device_map="auto",
+                                                         local_files_only=args.local_files_only).half()
+ model.eval()
+
+ # Load and check the audio
+ sample, sr = librosa.load(args.audio_path, sr=16000)
+ duration = sample.shape[-1] / sr
+ assert duration < 30, f"This script only supports audio shorter than 30 seconds; the current audio is {duration} seconds. Use another inference script!"
+
+ # Extract input features
+ input_features = processor(sample, sampling_rate=sr, return_tensors="pt", do_normalize=True).input_features.cuda().half()
+ # Generate token ids
+ predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids, max_new_tokens=256)
+ # Decode to text
+ transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+ print(f"result: {transcription}")
merge_lora.py ADDED
@@ -0,0 +1,47 @@
+ import argparse
+ import functools
+ import os
+
+ from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizerFast,\
+     WhisperProcessor
+ from peft import PeftModel, PeftConfig
+ from utils.utils import print_arguments, add_arguments
+
+ parser = argparse.ArgumentParser(description=__doc__)
+ add_arg = functools.partial(add_arguments, argparser=parser)
+ add_arg("lora_model", type=str, default="output/whisper-tiny/checkpoint-best/", help="")
+ add_arg('output_dir', type=str, default='models/', help="")
+ add_arg("local_files_only", type=bool, default=False, help="")
+ args = parser.parse_args()
+ print_arguments(args)
+
+ # Check that the LoRA checkpoint exists
+ assert os.path.exists(args.lora_model), f"{args.lora_model}"
+ # Load the LoRA configuration
+ peft_config = PeftConfig.from_pretrained(args.lora_model)
+ # Load the base Whisper model
+ base_model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path, device_map={"": "cpu"},
+                                                               local_files_only=args.local_files_only)
+ # Attach the LoRA adapters
+ model = PeftModel.from_pretrained(base_model, args.lora_model, local_files_only=args.local_files_only)
+ feature_extractor = WhisperFeatureExtractor.from_pretrained(peft_config.base_model_name_or_path,
+                                                              local_files_only=args.local_files_only)
+ tokenizer = WhisperTokenizerFast.from_pretrained(peft_config.base_model_name_or_path,
+                                                  local_files_only=args.local_files_only)
+ processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path,
+                                              local_files_only=args.local_files_only)
+
+ # Merge the LoRA weights into the base model
+ model = model.merge_and_unload()
+ model.train(False)
+
+ # Output directory for the merged model
+ save_directory = os.path.join(args.output_dir, f'{os.path.basename(peft_config.base_model_name_or_path)}-finetune')
+ os.makedirs(save_directory, exist_ok=True)
+
+ # Save the merged model and processor
+ model.save_pretrained(save_directory)
+ feature_extractor.save_pretrained(save_directory)
+ tokenizer.save_pretrained(save_directory)
+ processor.save_pretrained(save_directory)
+ print(f'model saved to: {save_directory}')
requirements.txt ADDED
@@ -0,0 +1,21 @@
+ numpy>=1.23.1
+ soundfile>=0.12.1
+ librosa>=0.10.0
+ dataclasses>=0.6
+ transformers>=4.31.0
+ bitsandbytes>=0.41.0
+ soundfile>=0.12.1
+ datasets>=2.11.0
+ evaluate>=0.4.0
+ faster-whisper>=0.7.0
+ jiwer>=2.5.1
+ peft>=0.4.0
+ accelerate>=0.21.0
+ zhconv>=1.4.2
+ tqdm>=4.62.1
+ soundcard>=0.4.2
+ uvicorn>=0.21.1
+ fastapi>=0.95.1
+ starlette>=0.26.1
+ tensorboardX>=2.2
+ python-multipart
run.sh ADDED
@@ -0,0 +1,22 @@
+ #!/bin/bash
+
+ CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 finetune.py --base_model=openai/whisper-tiny --use_8bit=False --per_device_train_batch_size=8 --per_device_eval_batch_size=8 --gradient_accumulation_steps=1
+ CUDA_VISIBLE_DEVICES=0 python merge_lora.py --lora_model=output/whisper-tiny/checkpoint-final
+
+ CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 finetune.py --base_model=openai/whisper-base --use_8bit=False --per_device_train_batch_size=8 --per_device_eval_batch_size=8 --gradient_accumulation_steps=1
+ CUDA_VISIBLE_DEVICES=0 python merge_lora.py --lora_model=output/whisper-base/checkpoint-final
+
+ CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 finetune.py --base_model=openai/whisper-small --use_8bit=True --per_device_train_batch_size=8 --per_device_eval_batch_size=8 --gradient_accumulation_steps=1
+ CUDA_VISIBLE_DEVICES=0 python merge_lora.py --lora_model=output/whisper-small/checkpoint-final
+
+ CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 finetune.py --base_model=openai/whisper-medium --use_8bit=True --per_device_train_batch_size=4 --per_device_eval_batch_size=2 --gradient_accumulation_steps=2
+ CUDA_VISIBLE_DEVICES=0 python merge_lora.py --lora_model=output/whisper-medium/checkpoint-final
+
+ CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 finetune.py --base_model=openai/whisper-large-v2 --use_8bit=True --per_device_train_batch_size=2 --per_device_eval_batch_size=2 --gradient_accumulation_steps=4
+ CUDA_VISIBLE_DEVICES=0 python merge_lora.py --lora_model=output/whisper-large-v2/checkpoint-final
+
+ CUDA_VISIBLE_DEVICES=0 python evaluation.py --model_path=models/whisper-tiny-finetune
+ CUDA_VISIBLE_DEVICES=0 python evaluation.py --model_path=models/whisper-base-finetune
+ CUDA_VISIBLE_DEVICES=0 python evaluation.py --model_path=models/whisper-small-finetune
+ CUDA_VISIBLE_DEVICES=0 python evaluation.py --model_path=models/whisper-medium-finetune
+ CUDA_VISIBLE_DEVICES=0 python evaluation.py --model_path=models/whisper-large-v2-finetune
static/index.css ADDED
@@ -0,0 +1,109 @@
+ * {
+     box-sizing: border-box;
+ }
+
+ body {
+     font-family: "Helvetica Neue", "Roboto", sans-serif;
+     background-color: #f2f2f2;
+     margin: 0;
+     padding: 0;
+ }
+
+ #header {
+     background-color: #fff;
+     color: #333;
+     display: flex;
+     justify-content: center;
+     align-items: center;
+     height: 80px;
+ }
+
+ h1 {
+     font-size: 36px;
+     margin: 0;
+ }
+
+ #content {
+     background-color: #fff;
+     border-radius: 10px;
+     box-shadow: 0 2px 10px rgba(0, 0, 0, 0.2);
+     margin: 50px auto;
+     max-width: 800px;
+     padding: 20px;
+ }
+
+ #content div {
+     display: flex;
+     flex-wrap: wrap;
+     justify-content: space-between;
+     margin-bottom: 20px;
+ }
+
+ #content a {
+     background-color: #fff;
+     border-radius: 5px;
+     box-shadow: 0 2px 10px rgba(0, 0, 0, 0.2);
+     color: #333;
+     padding: 10px;
+     text-align: center;
+     text-decoration: none;
+     transition: background-color 0.2s;
+     width: 20%;
+ }
+
+ #content a:hover {
+     background-color: #f2f2f2;
+ }
+
+ #content img {
+     cursor: pointer;
+     height: 50px;
+     transition: transform 0.2s;
+     width: 50px;
+ }
+
+ #content img:hover {
+     transform: scale(1.1);
+ }
+
+ #result {
+     background-color: #fff;
+     border-radius: 5px;
+     box-shadow: 0 2px 10px rgba(0, 0, 0, 0.2);
+     padding: 10px;
+ }
+
+ #result textarea {
+     border: none;
+     border-radius: 5px;
+     font-size: 16px;
+     height: 100px;
+     margin-top: 10px;
+     padding: 10px;
+     resize: none;
+     width: 100%;
+ }
+
+ /* #llm_result {
+     background-color: #fff;
+     border-radius: 5px;
+     box-shadow: 0 2px 10px rgba(0, 0, 0, 0.2);
+     padding: 10px;
+ }
+
+ #llm_result textarea {
+     border: none;
+     border-radius: 5px;
+     font-size: 16px;
+     height: 100px;
+     margin-top: 10px;
+     padding: 10px;
+     resize: none;
+     width: 100%;
+ } */
+
+ @media only screen and (max-width: 600px) {
+     #content a {
+         width: 100%;
+     }
+ }
static/record.js ADDED
@@ -0,0 +1,229 @@
+ // Compatibility shim
+ window.URL = window.URL || window.webkitURL;
+ // Get the computer's media devices: camera or recording device
+ navigator.getUserMedia = navigator.getUserMedia || navigator.webkitGetUserMedia || navigator.mozGetUserMedia || navigator.msGetUserMedia;
+
+ var HZRecorder = function (stream, config) {
+     config = config || {};
+     config.sampleBits = config.sampleBits || 16; // sample bits: 8 or 16
+     config.sampleRate = config.sampleRate || 16000; // sample rate: 16000
+
+     // Create an audio context
+     var audioContext = window.AudioContext || window.webkitAudioContext;
+     var context = new audioContext();
+     var audioInput = context.createMediaStreamSource(stream);
+     // The second and third arguments are the input and output channel counts (2 = stereo).
+     var recorder = context.createScriptProcessor(4096, 2, 2);
+
+     var audioData = {
+         size: 0 // length of the recorded data
+         , buffer: [] // recording buffer
+         , inputSampleRate: context.sampleRate // input sample rate
+         , inputSampleBits: 16 // input sample bits: 8 or 16
+         , outputSampleRate: config.sampleRate // output sample rate
+         , outputSampleBits: config.sampleBits // output sample bits: 8 or 16
+         , input: function (data) {
+             this.buffer.push(new Float32Array(data));
+             this.size += data.length;
+         }
+         , compress: function () { // merge and downsample
+             // merge
+             var data = new Float32Array(this.size);
+             var offset = 0;
+             for (var i = 0; i < this.buffer.length; i++) {
+                 data.set(this.buffer[i], offset);
+                 offset += this.buffer[i].length;
+             }
+             // downsample
+             var compression = parseInt(this.inputSampleRate / this.outputSampleRate);
+             var length = data.length / compression;
+             var result = new Float32Array(length);
+             var index = 0, j = 0;
+             while (index < length) {
+                 result[index] = data[j];
+                 j += compression;
+                 index++;
+             }
+             return result;
+         }
+         , encodeWAV: function () {
+             var sampleRate = Math.min(this.inputSampleRate, this.outputSampleRate);
+             var sampleBits = Math.min(this.inputSampleBits, this.outputSampleBits);
+             var bytes = this.compress();
+             var dataLength = bytes.length * (sampleBits / 8);
+             var buffer = new ArrayBuffer(44 + dataLength);
+             var data = new DataView(buffer);
+
+             var channelCount = 1; // mono
+             var offset = 0;
+
+             var writeString = function (str) {
+                 for (var i = 0; i < str.length; i++) {
+                     data.setUint8(offset + i, str.charCodeAt(i));
+                 }
+             }
+
+             // RIFF chunk identifier
+             writeString('RIFF');
+             offset += 4;
+             // Total bytes from the next address to the end of the file, i.e. file size - 8
+             data.setUint32(offset, 36 + dataLength, true);
+             offset += 4;
+             // WAVE format marker
+             writeString('WAVE');
+             offset += 4;
+             // "fmt " sub-chunk marker
+             writeString('fmt ');
+             offset += 4;
+             // Size of the fmt sub-chunk, normally 0x10 = 16
+             data.setUint32(offset, 16, true);
+             offset += 4;
+             // Audio format (1 = PCM)
+             data.setUint16(offset, 1, true);
+             offset += 2;
+             // Number of channels
+             data.setUint16(offset, channelCount, true);
+             offset += 2;
+             // Sample rate: samples per second per channel
+             data.setUint32(offset, sampleRate, true);
+             offset += 4;
+             // Byte rate (average bytes per second): channels * sample rate * bits per sample / 8
+             data.setUint32(offset, channelCount * sampleRate * (sampleBits / 8), true);
+             offset += 4;
+             // Block align (bytes per sample frame): channels * bits per sample / 8
+             data.setUint16(offset, channelCount * (sampleBits / 8), true);
+             offset += 2;
+             // Bits per sample
+             data.setUint16(offset, sampleBits, true);
+             offset += 2;
+             // "data" sub-chunk marker
+             writeString('data');
+             offset += 4;
+             // Total size of the sample data, i.e. total size - 44
+             data.setUint32(offset, dataLength, true);
+             offset += 4;
+             // Write the sample data
+             if (sampleBits === 8) {
+                 for (var i = 0; i < bytes.length; i++, offset++) {
+                     var s = Math.max(-1, Math.min(1, bytes[i]));
+                     var val = s < 0 ? s * 0x8000 : s * 0x7FFF;
+                     val = parseInt(255 / (65535 / (val + 32768)));
+                     data.setInt8(offset, val, true);
+                 }
+             } else {
+                 for (var i = 0; i < bytes.length; i++, offset += 2) {
+                     var s = Math.max(-1, Math.min(1, bytes[i]));
+                     data.setInt16(offset, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
+                 }
+             }
+
+             return new Blob([data], {type: 'audio/wav'});
+         }
+     };
+
+     // Start recording
+     this.start = function () {
+         audioInput.connect(recorder);
+         recorder.connect(context.destination);
+     }
+
+     // Stop recording
+     this.stop = function () {
+         recorder.disconnect();
+     }
+
+     // Get the recorded audio as a WAV blob
+     this.getBlob = function () {
+         this.stop();
+         return audioData.encodeWAV();
+     }
+
+     // Playback
+     this.play = function (audio) {
+         audio.src = window.URL.createObjectURL(this.getBlob());
+     }
+     // Clear the recording buffer
+     this.clear = function () {
+         audioData.buffer = [];
+         audioData.size = 0;
+     }
+
+     // Upload
+     this.upload = function (url, callback) {
+         var fd = new FormData();
+         // Field name and data for the upload
+         fd.append("audio", this.getBlob());
+         var xhr = new XMLHttpRequest();
+         xhr.timeout = 60000
+         if (callback) {
+             xhr.upload.addEventListener("progress", function (e) {
+                 callback('uploading', e);
+             }, false);
+             xhr.addEventListener("load", function (e) {
+                 callback('ok', e);
+             }, false);
+             xhr.addEventListener("error", function (e) {
+                 callback('error', e);
+             }, false);
+             xhr.addEventListener("abort", function (e) {
+                 callback('cancel', e);
+             }, false);
+         }
+         xhr.open("POST", url);
+         xhr.send(fd);
+     }
+
+     // Audio capture
+     recorder.onaudioprocess = function (e) {
+         audioData.input(e.inputBuffer.getChannelData(0));
+         //record(e.inputBuffer.getChannelData(0));
+     }
+
+ };
+ // Throw an error
+ HZRecorder.throwError = function (message) {
+     alert(message);
+     throw new function () {
+         this.toString = function () {
+             return message;
+         }
+     }
+ }
+ // Whether recording is supported
+ HZRecorder.canRecording = (navigator.getUserMedia != null);
+ // Get a recorder instance
+ HZRecorder.get = function (callback, config) {
+     if (callback) {
+         if (navigator.getUserMedia) {
+             navigator.getUserMedia(
+                 {audio: true} // audio only
+                 , function (stream) {
+                     var rec = new HZRecorder(stream, config);
+                     callback(rec);
+                 }
+                 , function (error) {
+                     switch (error.code || error.name) {
+                         case 'PERMISSION_DENIED':
+                         case 'PermissionDeniedError':
+                             HZRecorder.throwError('The user denied permission to record.');
+                             break;
+                         case 'NOT_SUPPORTED_ERROR':
+                         case 'NotSupportedError':
+                             HZRecorder.throwError('The browser does not support the hardware device.');
+                             break;
+                         case 'MANDATORY_UNSATISFIED_ERROR':
+                         case 'MandatoryUnsatisfiedError':
+                             HZRecorder.throwError('The specified hardware device could not be found.');
+                             break;
+                         default:
+                             HZRecorder.throwError('Unable to open the microphone. Error: ' + (error.code || error.name));
+                             break;
+                     }
+                 });
+         } else {
+             window.alert('Recording only works over HTTPS or on localhost!')
+             HZRecorder.throwErr('The current browser does not support recording.');
+             return;
+         }
+     }
+ };
static/record.png ADDED
static/recording.gif ADDED
templates/index.html ADDED
@@ -0,0 +1,167 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <title>OdiaGenAI Speech Recognition</title>
6
+ <script type="text/javascript" src="/static/record.js"></script>
7
+ <link href="/static/index.css" rel="stylesheet" type="text/css"/>
8
+ </head>
9
+ <body>
10
+ <div id="header">
11
+ <h1>OdiaGenAI Speech Recognition</h1>
12
+ </div>
13
+ <div id="content">
14
+ <div>
15
+ <a id="upload" onclick="uploadAudioFile()" class="file">select audio file</a>
16
+ <a id="play_btn" onclick="uploadRecordAudio()" class="file">predict audio file</a>
17
+ <audio controls autoplay></audio>
18
+ <img id="record_btn" onclick="record()" src="/static/record.png" alt="record"/>
19
+ </div>
20
+ <div id="result">
21
+ <label for="result_p"></label><textarea id="result_p"></textarea>
22
+ </div>
23
+ <!-- <div id="llm_result">
24
+ <a id="llm_predict" onclick="uploadAudioFile()" class="file">generate text</a>
25
+ <label for="result_llm"></label><textarea id="result_llm"></textarea>
26
+ </div> -->
27
+ </div>
28
+ <script>
29
+ let is_recording = false;
30
+ let is_playing = false;
31
+ let host = location.origin;
32
+ let recorder;
33
+ let audio = document.querySelector('audio');
34
+ let textarea = document.getElementById('result_p')
35
+
36
+
37
+ function record() {
38
+ if (is_recording) {
39
+ is_recording = false;
40
+ stopRecording()
41
+ document.getElementById('record_btn').src = '/static/record.png'
42
+ startPlay();
43
+ stopPlay();
44
+ } else {
45
+ is_recording = true;
46
+ startRecording()
47
+ document.getElementById('record_btn').src = '/static/recording.gif'
48
+ }
49
+ }
50
+
51
+ function play() {
52
+ if (is_playing) {
53
+ is_playing = false;
54
+ stopPlay()
55
+ document.getElementById('play_btn').innerText = 'play audio'
56
+ } else {
57
+ is_playing = true;
58
+ startPlay()
59
+ document.getElementById('play_btn').innerText = 'Stop play'
60
+ }
61
+ }
62
+
63
+ function startRecording() {
64
+ HZRecorder.get(function (rec) {
65
+ recorder = rec;
66
+ recorder.start();
67
+ });
68
+ }
69
+
70
+ function stopRecording() {
71
+ recorder.stop();
72
+ }
73
+
74
+ function startPlay() {
75
+ recorder.play(audio);
76
+ }
77
+
78
+ function stopPlay() {
79
+ audio.pause();
80
+ }
81
+
82
+ function cancelAudio() {
83
+ recorder.stop();
84
+ recorder.clear();
85
+ }
86
+
87
+ function uploadRecordAudio() {
88
+ recorder.upload(location.origin + "/recognition", function (state, e) {
89
+ switch (state) {
90
+ case 'uploading':
91
+ const percentComplete = Math.round(e.loaded * 100 / e.total) + '%';
92
+ console.log(percentComplete);
93
+ break;
94
+ case 'ok':
95
+ console.log(e.target.responseText)
96
+ document.getElementById('result_p').innerHTML = e.target.responseText
97
+ break;
98
+ case 'error':
99
+ alert("upload failed");
100
+ break;
101
+ case 'cancel':
102
+ alert("upload canceled");
103
+ break;
104
+ }
105
+ });
106
+ }
107
+
108
+ //
109
+ function uploadAudioFile() {
110
+ const input = document.createElement("input");
111
+ input.type = "file";
112
+ input.accept = "audio/*,video/*";
113
+ input.click();
114
+ input.onchange = function () {
115
+ const file = input.files[0];
116
+ console.log(file)
117
+ audio.src = window.URL.createObjectURL(file);
118
+ stopPlay();
119
+ upload_file(host + "/recognition", file, function (state, e) {
120
+ switch (state) {
121
+ case 'uploading':
122
+ const percentComplete = Math.round(e.loaded * 100 / e.total) + '%';
123
+ console.log(percentComplete);
124
+ break;
125
+ case 'ok':
126
+ console.log(e.target.responseText)
127
+ textarea.innerText = e.target.responseText
128
+ break;
129
+ case 'error':
130
+ alert("upload failed");
131
+ break;
132
+ case 'cancel':
133
+ alert("upload canceled");
134
+ break;
135
+ }
136
+ });
137
+ }
138
+ }
139
+
140
+ // POST a file to the given url as multipart form data, reporting progress through callback
141
+ upload_file = function (url, file, callback) {
142
+ const fd = new FormData();
143
+ // the server expects the file under the "audio" form field
144
+ fd.append("audio", file);
145
+ const xhr = new XMLHttpRequest();
146
+ xhr.timeout = 60000
147
+ if (callback) {
148
+ xhr.upload.addEventListener("progress", function (e) {
149
+ callback('uploading', e);
150
+ }, false);
151
+ xhr.addEventListener("load", function (e) {
152
+ callback('ok', e);
153
+ }, false);
154
+ xhr.addEventListener("error", function (e) {
155
+ callback('error', e);
156
+ }, false);
157
+ xhr.addEventListener("abort", function (e) {
158
+ callback('cancel', e);
159
+ }, false);
160
+ }
161
+ xhr.open("POST", url);
162
+ xhr.send(fd);
163
+ }
164
+ </script>
165
+
166
+ </body>
167
+ </html>
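
The page above posts the chosen or recorded audio as multipart form data to the /recognition endpoint, with the file under the "audio" field. As a sanity check outside the browser, a minimal sketch with Python requests (host and port are assumptions; point it at wherever infer_server.py is actually listening):

import requests

# hypothetical host/port; adjust to where infer_server.py is running
URL = "http://127.0.0.1:5001/recognition"

# "sample.wav" is a placeholder for any local audio file
with open("sample.wav", "rb") as f:
    resp = requests.post(URL, files={"audio": f}, timeout=60)

print(resp.status_code)
print(resp.text)  # the transcription returned by the server
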
utils/__init__.py ADDED
File without changes
utils/binary.py ADDED
@@ -0,0 +1,72 @@
1
+ import json
2
+ import mmap
3
+
4
+ import struct
5
+
6
+ from tqdm import tqdm
7
+
8
+
9
+ class DatasetWriter(object):
10
+ def __init__(self, prefix):
11
+ # the dataset is stored as a raw .data payload file plus a .header offset index
12
+ self.data_file = open(prefix + '.data', 'wb')
13
+ self.header_file = open(prefix + '.header', 'wb')
14
+ self.data_sum = 0
15
+ self.offset = 0
16
+ self.header = ''
17
+
18
+ def add_data(self, data):
19
+ key = str(self.data_sum)
20
+ data = bytes(data, encoding="utf8")
21
+ # write [key length][key][data length][data] into the .data file
22
+ self.data_file.write(struct.pack('I', len(key)))
23
+ self.data_file.write(key.encode('ascii'))
24
+ self.data_file.write(struct.pack('I', len(data)))
25
+ self.data_file.write(data)
26
+ # append "key \t payload offset \t payload length" to the .header index
27
+ self.offset += 4 + len(key) + 4
28
+ self.header = key + '\t' + str(self.offset) + '\t' + str(len(data)) + '\n'
29
+ self.header_file.write(self.header.encode('ascii'))
30
+ self.offset += len(data)
31
+ self.data_sum += 1
32
+
33
+ def close(self):
34
+ self.data_file.close()
35
+ self.header_file.close()
36
+
37
+
38
+ class DatasetReader(object):
39
+ def __init__(self, data_header_path, min_duration=0, max_duration=30):
40
+ self.keys = []
41
+ self.offset_dict = {}
42
+ self.fp = open(data_header_path.replace('.header', '.data'), 'rb')
43
+ self.m = mmap.mmap(self.fp.fileno(), 0, access=mmap.ACCESS_READ)
44
+ for line in tqdm(open(data_header_path, 'rb'), desc='reading data list'):
45
+ key, val_pos, val_len = line.split('\t'.encode('ascii'))
46
+ data = self.m[int(val_pos):int(val_pos) + int(val_len)]
47
+ data = str(data, encoding="utf-8")
48
+ data = json.loads(data)
49
+ # skip records whose duration is outside the allowed range
50
+ if data["duration"] < min_duration:
51
+ continue
52
+ if max_duration != -1 and data["duration"] > max_duration:
53
+ continue
54
+ self.keys.append(key)
55
+ self.offset_dict[key] = (int(val_pos), int(val_len))
56
+
57
+ # fetch a record by key and decode its JSON payload
58
+ def get_data(self, key):
59
+ p = self.offset_dict.get(key, None)
60
+ if p is None:
61
+ return None
62
+ val_pos, val_len = p
63
+ data = self.m[val_pos:val_pos + val_len]
64
+ data = str(data, encoding="utf-8")
65
+ return json.loads(data)
66
+
67
+ # keys of all records that passed the duration filter
68
+ def get_keys(self):
69
+ return self.keys
70
+
71
+ def __len__(self):
72
+ return len(self.keys)
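
For reference, a small round-trip with the two classes above. The file prefix and record fields are illustrative; the only hard requirement the reader imposes is a numeric "duration" field in every JSON record:

import json
from utils.binary import DatasetWriter, DatasetReader

# write two records; each record is a JSON string with at least a "duration" field
writer = DatasetWriter("train")  # produces train.data and train.header
writer.add_data(json.dumps({"audio": {"path": "a.wav"}, "sentence": "text a", "duration": 2.1}))
writer.add_data(json.dumps({"audio": {"path": "b.wav"}, "sentence": "text b", "duration": 5.4}))
writer.close()

# read the records back through the mmap-backed index
reader = DatasetReader("train.header", min_duration=0.5, max_duration=30)
for key in reader.get_keys():
    print(reader.get_data(key))
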
utils/callback.py ADDED
@@ -0,0 +1,37 @@
1
+ import os
3
+ import shutil
4
+
5
+ from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
6
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
7
+
8
+
9
+ # keeps the LoRA adapter and the best checkpoint around every time the trainer saves
10
+ class SavePeftModelCallback(TrainerCallback):
11
+ def on_save(self,
12
+ args: TrainingArguments,
13
+ state: TrainerState,
14
+ control: TrainerControl,
15
+ **kwargs, ):
16
+ if args.local_rank == 0 or args.local_rank == -1:
17
+ # only the main process writes checkpoints
18
+ checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
19
+ peft_model_dir = os.path.join(checkpoint_folder, "adapter_model")
20
+ kwargs["model"].save_pretrained(peft_model_dir)
21
+ peft_config_path = os.path.join(checkpoint_folder, "adapter_model/adapter_config.json")
22
+ peft_model_path = os.path.join(checkpoint_folder, "adapter_model/adapter_model.bin")
23
+ # assumed intent: if either adapter file is missing the save was incomplete,
24
+ # so drop the partially written adapter directory instead of keeping it
25
+ adapter_saved = os.path.exists(peft_config_path) and os.path.exists(peft_model_path)
26
+ if not adapter_saved and os.path.exists(peft_model_dir):
27
+ shutil.rmtree(peft_model_dir)
28
+
29
+ # mirror the best checkpoint into a stable "checkpoint-best" folder
30
+ best_checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-best")
31
+ # the best checkpoint may already have been rotated away by save_total_limit, so check first
32
+ if os.path.exists(state.best_model_checkpoint):
33
+ if os.path.exists(best_checkpoint_folder):
34
+ shutil.rmtree(best_checkpoint_folder)
35
+ shutil.copytree(state.best_model_checkpoint, best_checkpoint_folder)
36
+ print(f"{state.best_model_checkpoint}{state.best_metric}")
37
+ return control
utils/data_utils.py ADDED
@@ -0,0 +1,65 @@
1
+ import re
2
+ from dataclasses import dataclass
3
+ from typing import Any, List, Dict, Union
4
+
5
+ import torch
6
+ from zhconv import convert
7
+
8
+
9
+ # remove punctuation marks from a string or a list of strings
10
+ def remove_punctuation(text: Union[str, List[str]]):
11
+ punctuation = '!,.;:?、!,。;:?'
12
+ if isinstance(text, str):
13
+ text = re.sub(r'[{}]+'.format(punctuation), '', text).strip()
14
+ return text
15
+ elif isinstance(text, list):
16
+ result_text = []
17
+ for t in text:
18
+ t = re.sub(r'[{}]+'.format(punctuation), '', t).strip()
19
+ result_text.append(t)
20
+ return result_text
21
+ else:
22
+ raise Exception(f'Unsupported type: {type(text)}')
23
+
24
+
25
+ # convert Traditional Chinese text to Simplified Chinese
26
+ def to_simple(text: Union[str, List[str]]):
27
+ if isinstance(text, str):
28
+ text = convert(text, 'zh-cn')
29
+ return text
30
+ elif isinstance(text, list):
31
+ result_text = []
32
+ for t in text:
33
+ t = convert(t, 'zh-cn')
34
+ result_text.append(t)
35
+ return result_text
36
+ else:
37
+ raise Exception(f'Unsupported type: {type(text)}')
38
+
39
+
40
+ @dataclass
41
+ class DataCollatorSpeechSeq2SeqWithPadding:
42
+ processor: Any
43
+
44
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
45
+ # split inputs and labels since they have to be of different lengths and need different padding methods
46
+ # first treat the audio inputs by simply returning torch tensors
47
+ input_features = [{"input_features": feature["input_features"][0]} for feature in features]
48
+ batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
49
+
50
+ # get the tokenized label sequences
51
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
52
+ # pad the labels to max length
53
+ labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
54
+
55
+ # replace padding with -100 to ignore loss correctly
56
+ labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
57
+
58
+ # if bos token is appended in previous tokenization step,
59
+ # cut bos token here as it's append later anyways
60
+ if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
61
+ labels = labels[:, 1:]
62
+
63
+ batch["labels"] = labels
64
+
65
+ return batch
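
A quick illustration of the two text helpers above (the collator itself is exercised by the trainer); to_simple relies on the zhconv package being installed:

from utils.data_utils import remove_punctuation, to_simple

print(remove_punctuation("hello, world!"))            # -> "hello world"
print(remove_punctuation(["你好,世界!", "好的。"]))   # lists are handled element by element
print(to_simple("漢語轉換"))                           # Traditional -> Simplified Chinese
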
utils/model_utils.py ADDED
@@ -0,0 +1,20 @@
1
+ import bitsandbytes as bnb
2
+ import torch
3
+ from transformers.trainer_pt_utils import LabelSmoother
4
+
5
+ IGNORE_TOKEN_ID = LabelSmoother.ignore_index
6
+
7
+
8
+ def find_all_linear_names(use_8bit, model):
9
+ cls = bnb.nn.Linear8bitLt if use_8bit else torch.nn.Linear
10
+ lora_module_names = set()
11
+ for name, module in model.named_modules():
12
+ if isinstance(module, cls):
13
+ names = name.split('.')
14
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
15
+ target_modules = list(lora_module_names)
16
+ return target_modules
17
+
18
+
19
+ def load_from_checkpoint(resume_from_checkpoint, model=None):
20
+ pass
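
find_all_linear_names is meant to produce the target_modules list for a LoRA configuration. A sketch of how it could be used (the model name is only an example, and bitsandbytes must be installed because the module imports it unconditionally):

from transformers import WhisperForConditionalGeneration
from utils.model_utils import find_all_linear_names

# collect the names of every Linear layer so LoRA can target all of them
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
target_modules = find_all_linear_names(use_8bit=False, model=model)
print(target_modules)  # e.g. names such as 'q_proj', 'v_proj', 'fc1', 'fc2'
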
utils/pun_predictor.py ADDED
@@ -0,0 +1,110 @@
1
+ import json
2
+ import os
3
+ import re
4
+
5
+ import numpy as np
6
+ import paddle.inference as paddle_infer
7
+ from paddlenlp.transformers import ErnieTokenizer
8
+
9
+
10
+ __all__ = ['PunctuationExecutor']
11
+
12
+
13
+ class PunctuationExecutor:
14
+ def __init__(self, model_dir, use_gpu=True, gpu_mem=500, num_threads=4):
15
+ # config
16
+ model_path = os.path.join(model_dir, 'model.pdmodel')
17
+ params_path = os.path.join(model_dir, 'model.pdiparams')
18
+ if not os.path.exists(model_path) or not os.path.exists(params_path):
19
+ raise Exception("{}{}".format(model_path, params_path))
20
+ self.config = paddle_infer.Config(model_path, params_path)
21
+ # the tokenizer name is stored in info.json when the model is exported
22
+ pretrained_token = 'ernie-1.0'
23
+ if os.path.exists(os.path.join(model_dir, 'info.json')):
24
+ with open(os.path.join(model_dir, 'info.json'), 'r', encoding='utf-8') as f:
25
+ data = json.load(f)
26
+ pretrained_token = data['pretrained_token']
27
+
28
+ if use_gpu:
29
+ self.config.enable_use_gpu(gpu_mem, 0)
30
+ else:
31
+ self.config.disable_gpu()
32
+ self.config.set_cpu_math_library_num_threads(num_threads)
33
+ # enable memory optim
34
+ self.config.enable_memory_optim()
35
+ self.config.disable_glog_info()
36
+
37
+ # config predictor
38
+ self.predictor = paddle_infer.create_predictor(self.config)
39
+
40
+ # handles for the model's two input tensors
41
+ self.input_ids_handle = self.predictor.get_input_handle('input_ids')
42
+ self.token_type_ids_handle = self.predictor.get_input_handle('token_type_ids')
43
+
44
+ #
45
+ self.output_names = self.predictor.get_output_names()
46
+
47
+ self._punc_list = []
48
+ if not os.path.exists(os.path.join(model_dir, 'vocab.txt')):
49
+ raise Exception("Punctuation vocab file not found: {}".format(os.path.join(model_dir, 'vocab.txt')))
50
+ with open(os.path.join(model_dir, 'vocab.txt'), 'r', encoding='utf-8') as f:
51
+ for line in f:
52
+ self._punc_list.append(line.strip())
53
+
54
+ self.tokenizer = ErnieTokenizer.from_pretrained(pretrained_token)
55
+
56
+ # warm-up call so the first real request is not slowed down by lazy initialisation
57
+ self('预热')  # the text itself is arbitrary; any short sentence works here
58
+
59
+ def _clean_text(self, text):
60
+ text = text.lower()
61
+ text = re.sub('[^A-Za-z0-9\u4e00-\u9fa5]', '', text)
62
+ text = re.sub(f'[{"".join([p for p in self._punc_list][1:])}]', '', text)
63
+ return text
64
+
65
+ # clean and tokenize the text; returns None if nothing is left after cleaning
66
+ def preprocess(self, text: str):
67
+ clean_text = self._clean_text(text)
68
+ if len(clean_text) == 0: return None
69
+ tokenized_input = self.tokenizer(list(clean_text), return_length=True, is_split_into_words=True)
70
+ input_ids = tokenized_input['input_ids']
71
+ seg_ids = tokenized_input['token_type_ids']
72
+ seq_len = tokenized_input['seq_len']
73
+ return input_ids, seg_ids, seq_len
74
+
75
+ def infer(self, input_ids: list, seg_ids: list):
76
+ #
77
+ self.input_ids_handle.reshape([1, len(input_ids)])
78
+ self.token_type_ids_handle.reshape([1, len(seg_ids)])
79
+ self.input_ids_handle.copy_from_cpu(np.array([input_ids]).astype('int64'))
80
+ self.token_type_ids_handle.copy_from_cpu(np.array([seg_ids]).astype('int64'))
81
+
82
+ # predictor
83
+ self.predictor.run()
84
+
85
+ #
86
+ output_handle = self.predictor.get_output_handle(self.output_names[0])
87
+ output_data = output_handle.copy_to_cpu()
88
+ return output_data
89
+
90
+ # append a punctuation mark after every token whose predicted label is non-zero
91
+ def postprocess(self, input_ids, seq_len, preds):
92
+ tokens = self.tokenizer.convert_ids_to_tokens(input_ids[1:seq_len - 1])
93
+ labels = preds[1:seq_len - 1].tolist()
94
+ assert len(tokens) == len(labels)
95
+
96
+ text = ''
97
+ for t, l in zip(tokens, labels):
98
+ text += t
99
+ if l != 0:
100
+ text += self._punc_list[l]
101
+ return text
102
+
103
+ def __call__(self, text: str) -> str:
104
+ # empty or punctuation-only input cannot be tokenized; return it unchanged
105
+ result = self.preprocess(text)
+ if result is None: return text
+ input_ids, seg_ids, seq_len = result
106
+ preds = self.infer(input_ids=input_ids, seg_ids=seg_ids)
107
+ if len(preds.shape) == 2:
108
+ preds = preds[0]
109
+ text = self.postprocess(input_ids, seq_len, preds)
110
+ return text
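
Usage is a single call per utterance. The model directory below is a placeholder for an exported Paddle inference model containing model.pdmodel, model.pdiparams and vocab.txt; paddlepaddle and paddlenlp must be installed:

from utils.pun_predictor import PunctuationExecutor

# placeholder path to the exported punctuation model
pun = PunctuationExecutor(model_dir="models/pun_models", use_gpu=False)
print(pun("今天天气真好我们出去玩吧"))  # punctuation is inserted into the raw ASR text
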
utils/reader.py ADDED
@@ -0,0 +1,289 @@
1
+ import json
2
+ import os
3
+ import random
4
+ import sys
5
+ from typing import List
6
+
7
+ import librosa
8
+ import numpy as np
9
+ import soundfile
10
+ from torch.utils.data import Dataset
11
+ from tqdm import tqdm
12
+
13
+ from utils.binary import DatasetReader
14
+
15
+
16
+ class CustomDataset(Dataset):
17
+ def __init__(self,
18
+ data_list_path,
19
+ processor,
20
+ mono=True,
21
+ language=None,
22
+ timestamps=False,
23
+ sample_rate=16000,
24
+ min_duration=0.5,
25
+ max_duration=30,
26
+ augment_config_path=None):
27
+ """
28
+ Args:
29
+ data_list_path: path to the data list (a jsonlines file or a .header binary index)
30
+ processor: Whisper processor used for feature extraction and tokenization
31
+ mono: convert the audio to mono, defaults to True
32
+ language: language of the data, used to set the tokenizer prefix tokens
33
+ timestamps: whether to build labels with timestamp tokens
34
+ sample_rate: target sample rate, defaults to 16000
35
+ min_duration: drop utterances shorter than this, defaults to 0.5 s
36
+ max_duration: drop utterances longer than this, defaults to 30 s
37
+ augment_config_path: path to a JSON file describing data augmentation
38
+ """
39
+ super(CustomDataset, self).__init__()
40
+ assert min_duration >= 0.5, f"min_duration cannot be less than 0.5s, got: {min_duration}"
41
+ assert max_duration <= 30, f"max_duration cannot be greater than 30s, got: {max_duration}"
42
+ self.data_list_path = data_list_path
43
+ self.processor = processor
45
+ self.sample_rate = sample_rate
46
+ self.mono = mono
47
+ self.language = language
48
+ self.timestamps = timestamps
49
+ self.min_duration = min_duration
50
+ self.max_duration = max_duration
51
+ self.vocab = self.processor.tokenizer.get_vocab()
52
+ self.timestamp_begin = self.vocab['<|notimestamps|>'] + 1
53
+ self.startoftranscript = self.vocab['<|startoftranscript|>']
54
+ self.endoftext = self.vocab['<|endoftext|>']
55
+ self.nocaptions = self.vocab['<|nocaptions|>']
56
+ self.data_list: List[dict] = []
57
+ #
58
+ self._load_data_list()
59
+ #
60
+ self.augment_configs = None
61
+ self.noises_path = None
62
+ self.speed_rates = None
63
+ if augment_config_path:
64
+ with open(augment_config_path, 'r', encoding='utf-8') as f:
65
+ self.augment_configs = json.load(f)
66
+
67
+ # load the data list either from a binary .header index or from a jsonlines file
68
+ def _load_data_list(self):
69
+ if self.data_list_path.endswith(".header"):
70
+ #
71
+ self.dataset_reader = DatasetReader(data_header_path=self.data_list_path,
72
+ min_duration=self.min_duration,
73
+ max_duration=self.max_duration)
74
+ self.data_list = self.dataset_reader.get_keys()
75
+ else:
76
+ #
77
+ with open(self.data_list_path, 'r', encoding='utf-8') as f:
78
+ lines = f.readlines()
79
+ self.data_list = []
80
+ for line in tqdm(lines, desc='loading data list'):
81
+ if isinstance(line, str):
82
+ line = json.loads(line)
83
+ if not isinstance(line, dict): continue
84
+ #
85
+ if line["duration"] < self.min_duration:
86
+ continue
87
+ if self.max_duration != -1 and line["duration"] > self.max_duration:
88
+ continue
89
+ self.data_list.append(dict(line))
90
+
91
+ # read one record: audio samples, sample rate, transcript and language
92
+ def _get_list_data(self, idx):
93
+ if self.data_list_path.endswith(".header"):
94
+ data_list = self.dataset_reader.get_data(self.data_list[idx])
95
+ else:
96
+ data_list = self.data_list[idx]
97
+ #
98
+ audio_file = data_list["audio"]['path']
99
+ transcript = data_list["sentences"] if self.timestamps else data_list["sentence"]
100
+ language = data_list["language"] if 'language' in data_list.keys() else None
101
+ if 'start_time' not in data_list["audio"].keys():
102
+ sample, sample_rate = soundfile.read(audio_file, dtype='float32')
103
+ else:
104
+ start_time, end_time = data_list["audio"]["start_time"], data_list["audio"]["end_time"]
105
+ #
106
+ sample, sample_rate = self.slice_from_file(audio_file, start=start_time, end=end_time)
107
+ sample = sample.T
108
+ #
109
+ if self.mono:
110
+ sample = librosa.to_mono(sample)
111
+ #
112
+ if self.augment_configs:
113
+ sample, sample_rate = self.augment(sample, sample_rate)
114
+ #
115
+ if self.sample_rate != sample_rate:
116
+ sample = self.resample(sample, orig_sr=sample_rate, target_sr=self.sample_rate)
117
+ return sample, sample_rate, transcript, language
118
+
119
+ def _load_timestamps_transcript(self, transcript: List[dict]):
120
+ assert isinstance(transcript, list), f"transcript must be a list, got: {type(transcript)}"
121
+ data = dict()
122
+ labels = self.processor.tokenizer.prefix_tokens[:3]
123
+ for t in transcript:
124
+ # Whisper timestamp tokens use a 0.02 s grid, so snap start/end onto it
125
+ start = t['start'] if round(t['start'] * 100) % 2 == 0 else t['start'] + 0.01
126
+ start = self.timestamp_begin + round(start * 100) // 2
127
+ end = t['end'] if round(t['end'] * 100) % 2 == 0 else t['end'] - 0.01
128
+ end = self.timestamp_begin + round(end * 100) // 2
129
+ label = self.processor(text=t['text']).input_ids[4:-1]
130
+ labels.extend([start])
131
+ labels.extend(label)
132
+ labels.extend([end])
133
+ data['labels'] = labels + [self.endoftext]
134
+ return data
135
+
136
+ def __getitem__(self, idx):
137
+ try:
138
+ #
139
+ sample, sample_rate, transcript, language = self._get_list_data(idx=idx)
140
+ #
141
+ self.processor.tokenizer.set_prefix_tokens(language=language if language is not None else self.language)
142
+ if len(transcript) > 0:
143
+ #
144
+ if self.timestamps:
145
+ data = self._load_timestamps_transcript(transcript=transcript)
146
+ #
147
+ data["input_features"] = self.processor(audio=sample, sampling_rate=self.sample_rate).input_features
148
+ else:
149
+ #
150
+ data = self.processor(audio=sample, sampling_rate=self.sample_rate, text=transcript)
151
+ else:
152
+ #
153
+ data = self.processor(audio=sample, sampling_rate=self.sample_rate)
154
+ data['labels'] = [self.startoftranscript, self.nocaptions, self.endoftext]
155
+ return data
156
+ except Exception as e:
157
+ print(f'idx:{idx} error - {e}', file=sys.stderr)
158
+ return self.__getitem__(random.randint(0, self.__len__() - 1))
159
+
160
+ def __len__(self):
161
+ return len(self.data_list)
162
+
163
+ # read only the [start, end] slice of an audio file instead of loading it whole
164
+ @staticmethod
165
+ def slice_from_file(file, start, end):
166
+ sndfile = soundfile.SoundFile(file)
167
+ sample_rate = sndfile.samplerate
168
+ duration = round(float(len(sndfile)) / sample_rate, 3)
169
+ start = round(start, 3)
170
+ end = round(end, 3)
171
+ #
172
+ if start < 0.0: start += duration
173
+ if end < 0.0: end += duration
174
+ #
175
+ if start < 0.0: start = 0.0
176
+ if end > duration: end = duration
177
+ if end < 0.0:
178
+ raise ValueError("(%f s)" % end)
179
+ if start > end:
180
+ raise ValueError("(%f s)(%f s)" % (start, end))
181
+ start_frame = int(start * sample_rate)
182
+ end_frame = int(end * sample_rate)
183
+ sndfile.seek(start_frame)
184
+ sample = sndfile.read(frames=end_frame - start_frame, dtype='float32')
185
+ return sample, sample_rate
186
+
187
+ # apply the configured augmentations: speed, shift, volume, resample, noise
188
+ def augment(self, sample, sample_rate):
189
+ for config in self.augment_configs:
190
+ if config['type'] == 'speed' and random.random() < config['prob']:
191
+ if self.speed_rates is None:
192
+ min_speed_rate, max_speed_rate, num_rates = config['params']['min_speed_rate'], \
193
+ config['params']['max_speed_rate'], config['params']['num_rates']
194
+ self.speed_rates = np.linspace(min_speed_rate, max_speed_rate, num_rates, endpoint=True)
195
+ rate = random.choice(self.speed_rates)
196
+ sample = self.change_speed(sample, speed_rate=rate)
197
+ if config['type'] == 'shift' and random.random() < config['prob']:
198
+ min_shift_ms, max_shift_ms = config['params']['min_shift_ms'], config['params']['max_shift_ms']
199
+ shift_ms = random.randint(min_shift_ms, max_shift_ms)
200
+ sample = self.shift(sample, sample_rate, shift_ms=shift_ms)
201
+ if config['type'] == 'volume' and random.random() < config['prob']:
202
+ min_gain_dBFS, max_gain_dBFS = config['params']['min_gain_dBFS'], config['params']['max_gain_dBFS']
203
+ gain = random.randint(min_gain_dBFS, max_gain_dBFS)
204
+ sample = self.volume(sample, gain=gain)
205
+ if config['type'] == 'resample' and random.random() < config['prob']:
206
+ new_sample_rates = config['params']['new_sample_rates']
207
+ new_sample_rate = np.random.choice(new_sample_rates)
208
+ sample = self.resample(sample, orig_sr=sample_rate, target_sr=new_sample_rate)
209
+ sample_rate = new_sample_rate
210
+ if config['type'] == 'noise' and random.random() < config['prob']:
211
+ min_snr_dB, max_snr_dB = config['params']['min_snr_dB'], config['params']['max_snr_dB']
212
+ if self.noises_path is None:
213
+ self.noises_path = []
214
+ noise_dir = config['params']['noise_dir']
215
+ if os.path.exists(noise_dir):
216
+ for file in os.listdir(noise_dir):
217
+ self.noises_path.append(os.path.join(noise_dir, file))
218
+ noise_path = random.choice(self.noises_path)
219
+ snr_dB = random.randint(min_snr_dB, max_snr_dB)
220
+ sample = self.add_noise(sample, sample_rate, noise_path=noise_path, snr_dB=snr_dB)
221
+ return sample, sample_rate
222
+
223
+ # change the playback speed via linear interpolation
224
+ @staticmethod
225
+ def change_speed(sample, speed_rate):
226
+ if speed_rate == 1.0:
227
+ return sample
228
+ if speed_rate <= 0:
229
+ raise ValueError("error")
230
+ old_length = sample.shape[0]
231
+ new_length = int(old_length / speed_rate)
232
+ old_indices = np.arange(old_length)
233
+ new_indices = np.linspace(start=0, stop=old_length, num=new_length)
234
+ sample = np.interp(new_indices, old_indices, sample).astype(np.float32)
235
+ return sample
236
+
237
+ # shift the audio in time by shift_ms milliseconds
238
+ @staticmethod
239
+ def shift(sample, sample_rate, shift_ms):
240
+ duration = sample.shape[0] / sample_rate
241
+ if abs(shift_ms) / 1000.0 > duration:
242
+ raise ValueError("shift_ms")
243
+ shift_samples = int(shift_ms * sample_rate / 1000)
244
+ if shift_samples > 0:
245
+ sample[:-shift_samples] = sample[shift_samples:]
246
+ sample[-shift_samples:] = 0
247
+ elif shift_samples < 0:
248
+ sample[-shift_samples:] = sample[:shift_samples]
249
+ sample[:-shift_samples] = 0
250
+ return sample
251
+
252
+ # apply a gain in dB
253
+ @staticmethod
254
+ def volume(sample, gain):
255
+ sample *= 10.**(gain / 20.)
256
+ return sample
257
+
258
+ # resample the audio to a new sample rate
259
+ @staticmethod
260
+ def resample(sample, orig_sr, target_sr):
261
+ sample = librosa.resample(sample, orig_sr=orig_sr, target_sr=target_sr)
262
+ return sample
263
+
264
+ # mix a noise clip into the sample at the requested SNR
265
+ def add_noise(self, sample, sample_rate, noise_path, snr_dB, max_gain_db=300.0):
266
+ noise_sample, sr = librosa.load(noise_path, sr=sample_rate)
267
+ #
268
+ target_db = -20
269
+ gain = min(max_gain_db, target_db - self.rms_db(sample))
270
+ sample *= 10. ** (gain / 20.)
271
+ #
272
+ sample_rms_db, noise_rms_db = self.rms_db(sample), self.rms_db(noise_sample)
273
+ noise_gain_db = min(sample_rms_db - noise_rms_db - snr_dB, max_gain_db)
274
+ noise_sample *= 10. ** (noise_gain_db / 20.)
275
+ #
276
+ if noise_sample.shape[0] < sample.shape[0]:
277
+ diff_duration = sample.shape[0] - noise_sample.shape[0]
278
+ noise_sample = np.pad(noise_sample, (0, diff_duration), 'wrap')
279
+ elif noise_sample.shape[0] > sample.shape[0]:
280
+ start_frame = random.randint(0, noise_sample.shape[0] - sample.shape[0])
281
+ noise_sample = noise_sample[start_frame:sample.shape[0] + start_frame]
282
+ sample += noise_sample
283
+ return sample
284
+
285
+ @staticmethod
286
+ def rms_db(sample):
287
+ mean_square = np.mean(sample ** 2)
288
+ return 10 * np.log10(mean_square)
289
+
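
A sketch of how the dataset is typically wired up for fine-tuning. The processor name and data list path are placeholders; each jsonlines record is expected to look like {"audio": {"path": ...}, "sentence": ..., "duration": ...}:

from transformers import WhisperProcessor
from utils.reader import CustomDataset
from utils.data_utils import DataCollatorSpeechSeq2SeqWithPadding

# placeholder processor name and data list path
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
train_dataset = CustomDataset(data_list_path="dataset/train.jsonl",
                              processor=processor,
                              min_duration=0.5,
                              max_duration=30)
print(len(train_dataset))
sample = train_dataset[0]  # dict with "input_features" and "labels"

# the collator pads features and labels into a training batch
collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
batch = collator([train_dataset[0], train_dataset[1]])
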
utils/utils.py ADDED
@@ -0,0 +1,87 @@
1
+ import hashlib
2
+ import os
3
+ import tarfile
4
+ import urllib.request
5
+
6
+ from tqdm import tqdm
7
+
8
+
9
+ def print_arguments(args):
10
+ print("----------- Configuration Arguments -----------")
11
+ for arg, value in vars(args).items():
12
+ print("%s: %s" % (arg, value))
13
+ print("------------------------------------------------")
14
+
15
+
16
+ def strtobool(val):
17
+ val = val.lower()
18
+ if val in ('y', 'yes', 't', 'true', 'on', '1'):
19
+ return True
20
+ elif val in ('n', 'no', 'f', 'false', 'off', '0'):
21
+ return False
22
+ else:
23
+ raise ValueError("invalid truth value %r" % (val,))
24
+
25
+
26
+ def str_none(val):
27
+ if val == 'None':
28
+ return None
29
+ else:
30
+ return val
31
+
32
+
33
+ def add_arguments(argname, type, default, help, argparser, **kwargs):
34
+ type = strtobool if type == bool else type
35
+ type = str_none if type == str else type
36
+ argparser.add_argument("--" + argname,
37
+ default=default,
38
+ type=type,
39
+ help=help + ' Default: %(default)s.',
40
+ **kwargs)
41
+
42
+
43
+ def md5file(fname):
44
+ hash_md5 = hashlib.md5()
45
+ f = open(fname, "rb")
46
+ for chunk in iter(lambda: f.read(4096), b""):
47
+ hash_md5.update(chunk)
48
+ f.close()
49
+ return hash_md5.hexdigest()
50
+
51
+
52
+ def download(url, md5sum, target_dir):
53
+ """Download file from url to target_dir, and check md5sum."""
54
+ if not os.path.exists(target_dir): os.makedirs(target_dir)
55
+ filepath = os.path.join(target_dir, url.split("/")[-1])
56
+ if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
57
+ print(f"Downloading {url} to {filepath} ...")
58
+ with urllib.request.urlopen(url) as source, open(filepath, "wb") as output:
59
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True,
60
+ unit_divisor=1024) as loop:
61
+ while True:
62
+ buffer = source.read(8192)
63
+ if not buffer:
64
+ break
65
+
66
+ output.write(buffer)
67
+ loop.update(len(buffer))
68
+ print(f"\nMD5 Chesksum {filepath} ...")
69
+ if not md5file(filepath) == md5sum:
70
+ raise RuntimeError("MD5 checksum failed.")
71
+ else:
72
+ print(f"File exists, skip downloading. ({filepath})")
73
+ return filepath
74
+
75
+
76
+ def unpack(filepath, target_dir, rm_tar=False):
77
+ """Unpack the file to the target_dir."""
78
+ print("Unpacking %s ..." % filepath)
79
+ tar = tarfile.open(filepath)
80
+ tar.extractall(target_dir)
81
+ tar.close()
82
+ if rm_tar:
83
+ os.remove(filepath)
84
+
85
+
86
+ def make_inputs_require_grad(module, input, output):
87
+ output.requires_grad_(True)
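
The two argparse helpers above are meant to be used together, usually through functools.partial. A minimal sketch (the argument names and defaults are examples only):

import argparse
import functools

from utils.utils import add_arguments, print_arguments

parser = argparse.ArgumentParser()
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg("model_path", type=str, default="models/whisper-small", help="path of the model to load")
add_arg("use_gpu", type=bool, default=True, help="whether to run on the GPU")
args = parser.parse_args()
print_arguments(args)  # prints the resolved configuration block
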