Synced repo using 'sync_with_huggingface' Github Action
- Dockerfile +69 -0
- LICENSE +201 -0
- configs/augmentation.json +43 -0
- dataset/test.mp3 +0 -0
- download.py +44 -0
- evaluation.py +96 -0
- finetune.py +160 -0
- infer_ct2.py +46 -0
- infer_server.py +143 -0
- infer_tfs.py +43 -0
- merge_lora.py +47 -0
- requirements.txt +21 -0
- run.sh +22 -0
- static/index.css +109 -0
- static/record.js +229 -0
- static/record.png +0 -0
- static/recording.gif +0 -0
- templates/index.html +167 -0
- utils/__init__.py +0 -0
- utils/binary.py +72 -0
- utils/callback.py +37 -0
- utils/data_utils.py +65 -0
- utils/model_utils.py +20 -0
- utils/pun_predictor.py +110 -0
- utils/reader.py +289 -0
- utils/utils.py +87 -0
Dockerfile
ADDED
@@ -0,0 +1,69 @@
FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04

# Use Python 3.11 for better Python perf
# Update the package lists and install necessary dependencies
RUN apt-get update && apt-get install -y \
    software-properties-common \
    && add-apt-repository -y ppa:deadsnakes/ppa \
    && apt-get update \
    && apt-get install -y python3.11 python3.11-dev

# Set Python 3.11 as the default version (for python3)
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1

# Download the get-pip.py script
RUN apt install curl -y
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py

# Install pip for Python 3.11
RUN python3 get-pip.py

# Verify Python and pip versions
RUN python3 --version && pip3.11 --version

# Set pip3.11 as the default pip command
RUN update-alternatives --install /usr/bin/pip3 pip3 /usr/local/lib/python3.11/dist-packages/pip 1

ENV PYTHONUNBUFFERED=1

# Install necessary dependencies
# RUN apt-get update && \
#     apt-get install -y python3-pip

# Set the working directory. /app is mounted to the container with -v,
# but we want to have the right cwd for the uvicorn command below
RUN mkdir /app
# WORKDIR /app

# # Copy the app code and requirements file
# COPY . /app
# COPY requirements.txt .
# WORKDIR $PYSETUP_PATH
COPY ./requirements.txt /app

COPY ./utils /app/utils
COPY ./static /app/static
COPY ./templates /app/templates
COPY ./infer_server.py /app/infer_server.py
COPY ./download.py /app/download.py

WORKDIR /app

# Install the app dependencies
# RUN pip3 install -r requirements.txt

RUN --mount=type=cache,target=/root/.cache/pip \
    pip3 install -r requirements.txt

# Expose the FastAPI port
EXPOSE 5001

# Start the FastAPI app using the Uvicorn web server
# CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "14000", "--limit-concurrency", "1000"]
RUN python3 download.py

CMD ["python3", "infer_server.py", "--host=0.0.0.0", "--port=5001", "--model_path=models/sam2ai/whisper-odia-small-finetune-int8-ct2", "--num_workers=2"]
LICENSE
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

[Standard Apache License 2.0 text: Sections 1-9 (Definitions; Grant of Copyright License; Grant of Patent License; Redistribution; Submission of Contributions; Trademarks; Disclaimer of Warranty; Limitation of Liability; Accepting Warranty or Additional Liability), END OF TERMS AND CONDITIONS, and the Appendix with the "Copyright [yyyy] [name of copyright owner]" boilerplate notice.]
configs/augmentation.json
ADDED
@@ -0,0 +1,43 @@
[
  {
    "type": "resample",
    "params": {
      "new_sample_rates": [8000, 32000, 44100]
    },
    "prob": 0.0
  },
  {
    "type": "noise",
    "params": {
      "min_snr_dB": 10,
      "max_snr_dB": 50,
      "noise_dir": "dataset/noise"
    },
    "prob": 0.2
  },
  {
    "type": "speed",
    "params": {
      "min_speed_rate": 0.9,
      "max_speed_rate": 1.1,
      "num_rates": 3
    },
    "prob": 0.5
  },
  {
    "type": "shift",
    "params": {
      "min_shift_ms": -5,
      "max_shift_ms": 5
    },
    "prob": 0.0
  },
  {
    "type": "volume",
    "params": {
      "min_gain_dBFS": -15,
      "max_gain_dBFS": 15
    },
    "prob": 0.5
  }
]
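
For illustration only, a minimal sketch of how a config like this could be consumed: each entry is tried independently with probability "prob". The real logic lives in utils/reader.py (listed in the summary above but not shown in this diff), so the helper below is a hypothetical reading of the format, not the repo's implementation.

import json
import random

def load_augmentations(path="configs/augmentation.json"):
    # Each entry: {"type": ..., "params": {...}, "prob": float}
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def pick_augmentations(configs):
    # Independently keep each augmentation with its configured probability.
    chosen = []
    for cfg in configs:
        if random.random() < cfg["prob"]:
            chosen.append((cfg["type"], cfg["params"]))
    return chosen

if __name__ == "__main__":
    augs = load_augmentations()
    print(pick_augmentations(augs))  # e.g. [("speed", {...}), ("volume", {...})]
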
dataset/test.mp3
ADDED
Binary file (61.7 kB)
download.py
ADDED
@@ -0,0 +1,44 @@
import argparse
import requests
import os
from tqdm import tqdm

def download_file(url, path):
    response = requests.get(url, stream=True)
    total_size_in_bytes = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 Kbyte
    progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)

    with open(path, 'wb') as file:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
            file.write(data)

    progress_bar.close()

def download_model(model_name, destination_folder="models"):
    # Define the base URL and headers for the Hugging Face API
    base_url = f"https://huggingface.co/{model_name}/resolve/main"
    headers = {"User-Agent": "Hugging Face Python"}

    # Send a GET request to the Hugging Face API to get a list of all files
    response = requests.get(f"https://huggingface.co/api/models/{model_name}", headers=headers)
    response.raise_for_status()

    # Extract the list of files from the response JSON
    files_to_download = [file["rfilename"] for file in response.json()["siblings"]]

    # Ensure the directory exists
    os.makedirs(f"{destination_folder}/{model_name}", exist_ok=True)

    # Download each file
    for file in files_to_download:
        print(f"Downloading {file}...")
        download_file(f"{base_url}/{file}", f"{destination_folder}/{model_name}/{file}")

if __name__ == "__main__":
    # parser = argparse.ArgumentParser()
    # parser.add_argument("model_name", type=str, default="sam2ai/whisper-odia-small-finetune-int8-ct2", help="Name of the model to download.")
    # args = parser.parse_args()

    download_model("sam2ai/whisper-odia-small-finetune-int8-ct2")
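
For illustration, a usage sketch of the helper above with a different repo id; the committed script hard-codes sam2ai/whisper-odia-small-finetune-int8-ct2 and leaves its argparse interface commented out, so the model name below is just an example.

from download import download_model

# Example only: mirror any public Hugging Face repo into models/<repo_id>/
download_model("openai/whisper-tiny", destination_folder="models")
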
evaluation.py
ADDED
@@ -0,0 +1,96 @@
import argparse
import functools
import gc
import os

import evaluate
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import WhisperForConditionalGeneration, WhisperProcessor

from utils.data_utils import DataCollatorSpeechSeq2SeqWithPadding, remove_punctuation, to_simple
from utils.reader import CustomDataset
from utils.utils import print_arguments, add_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg("test_data",   type=str, default="dataset/test.json",            help="Path to the test set")
add_arg("model_path",  type=str, default="models/whisper-tiny-finetune", help="Path to the merged model, or a model name on Hugging Face")
add_arg("batch_size",  type=int, default=16,    help="Evaluation batch size")
add_arg("num_workers", type=int, default=8,     help="Number of data-loading workers")
add_arg("language",    type=str, default="Chinese", help="Language, full name or abbreviation; if None, evaluate multilingually")
add_arg("remove_pun",  type=bool, default=True,  help="Whether to remove punctuation")
add_arg("to_simple",   type=bool, default=True,  help="Whether to convert to Simplified Chinese")
add_arg("timestamps",  type=bool, default=False, help="Whether to use timestamp data during evaluation")
add_arg("min_audio_len", type=float, default=0.5, help="Minimum audio length in seconds")
add_arg("max_audio_len", type=float, default=30,  help="Maximum audio length in seconds")
add_arg("local_files_only", type=bool, default=True, help="Only load the model locally, do not try to download")
add_arg("task",   type=str, default="transcribe", choices=['transcribe', 'translate'], help="Model task")
add_arg("metric", type=str, default="cer", choices=['cer', 'wer'], help="Evaluation metric")
args = parser.parse_args()
print_arguments(args)

# Check that the model path is valid
assert 'openai' == os.path.dirname(args.model_path) or os.path.exists(args.model_path), \
    f"Model {args.model_path} does not exist; check that the model was merged successfully, or that it exists on Hugging Face"
# Get the Whisper processor, which contains the feature extractor and tokenizer
processor = WhisperProcessor.from_pretrained(args.model_path,
                                             language=args.language,
                                             task=args.task,
                                             no_timestamps=not args.timestamps,
                                             local_files_only=args.local_files_only)
forced_decoder_ids = processor.get_decoder_prompt_ids()
# Load the model
model = WhisperForConditionalGeneration.from_pretrained(args.model_path,
                                                        device_map="auto",
                                                        local_files_only=args.local_files_only)
model.eval()

# Load the test data
test_dataset = CustomDataset(data_list_path=args.test_data,
                             processor=processor,
                             timestamps=args.timestamps,
                             min_duration=args.min_audio_len,
                             max_duration=args.max_audio_len)
print(f"Test samples: {len(test_dataset)}")

# Data padding collator
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
eval_dataloader = DataLoader(test_dataset, batch_size=args.batch_size,
                             num_workers=args.num_workers, collate_fn=data_collator)

# Get the evaluation metric
metric = evaluate.load(args.metric)

# Start evaluation
for step, batch in enumerate(tqdm(eval_dataloader)):
    with torch.cuda.amp.autocast():
        with torch.no_grad():
            generated_tokens = (
                model.generate(
                    input_features=batch["input_features"].cuda(),
                    decoder_input_ids=batch["labels"][:, :4].cuda(),
                    forced_decoder_ids=forced_decoder_ids,
                    max_new_tokens=255).cpu().numpy())
            labels = batch["labels"].cpu().numpy()
            labels = np.where(labels != -100, labels, processor.tokenizer.pad_token_id)
            # Convert predicted and reference tokens to text
            decoded_preds = processor.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            decoded_labels = processor.tokenizer.batch_decode(labels, skip_special_tokens=True)
            # Remove punctuation
            if args.remove_pun:
                decoded_preds = remove_punctuation(decoded_preds)
                decoded_labels = remove_punctuation(decoded_labels)
            # Convert Traditional Chinese to Simplified Chinese
            if args.to_simple:
                decoded_preds = to_simple(decoded_preds)
                decoded_labels = to_simple(decoded_labels)
            metric.add_batch(predictions=decoded_preds, references=decoded_labels)
    # Free the per-batch tensors
    del generated_tokens, labels, batch
    gc.collect()
# Compute the evaluation result
m = metric.compute()
print(f"Evaluation result: {args.metric}={round(m, 5)}")
finetune.py
ADDED
@@ -0,0 +1,160 @@
import argparse
import functools
import os
import platform

import torch
from peft import LoraConfig, get_peft_model, AdaLoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, WhisperForConditionalGeneration, WhisperProcessor

from utils.callback import SavePeftModelCallback
from utils.data_utils import DataCollatorSpeechSeq2SeqWithPadding
from utils.model_utils import load_from_checkpoint
from utils.reader import CustomDataset
from utils.utils import print_arguments, make_inputs_require_grad, add_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg("train_data",    type=str, default="dataset/train.json", help="")
add_arg("test_data",     type=str, default="dataset/test.json",  help="")
add_arg("base_model",    type=str, default="openai/whisper-tiny", help="Whisper")
add_arg("output_dir",    type=str, default="output/", help="")
add_arg("warmup_steps",  type=int, default=50,   help="")
add_arg("logging_steps", type=int, default=100,  help="")
add_arg("eval_steps",    type=int, default=1000, help="")
add_arg("save_steps",    type=int, default=1000, help="")
add_arg("num_workers",   type=int, default=8,    help="")
add_arg("learning_rate", type=float, default=1e-3, help="")
add_arg("min_audio_len", type=float, default=0.5,  help="")
add_arg("max_audio_len", type=float, default=30,   help="")
add_arg("use_adalora",   type=bool, default=True,  help="AdaLora/Lora")
add_arg("fp16",          type=bool, default=True,  help="fp16")
add_arg("use_8bit",      type=bool, default=False, help="8 bit")
add_arg("timestamps",    type=bool, default=False, help="")
add_arg("local_files_only", type=bool, default=False, help="")
add_arg("num_train_epochs", type=int, default=3, help="")
add_arg("language",      type=str, default="bn", help="")
add_arg("task",          type=str, default="transcribe", choices=['transcribe', 'translate'], help="Model task")
add_arg("augment_config_path",    type=str, default=None, help="")
add_arg("resume_from_checkpoint", type=str, default=None, help="")
add_arg("per_device_train_batch_size", type=int, default=8, help="batch size")
add_arg("per_device_eval_batch_size",  type=int, default=8, help="batch size")
add_arg("gradient_accumulation_steps", type=int, default=1, help="")

args = parser.parse_args()
print_arguments(args)


# Whisper tokenizer
processor = WhisperProcessor.from_pretrained(args.base_model,
                                             language=args.language,
                                             task=args.task,
                                             no_timestamps=not args.timestamps,
                                             local_files_only=args.local_files_only)

# Datasets
train_dataset = CustomDataset(data_list_path=args.train_data,
                              processor=processor,
                              language=args.language,
                              timestamps=args.timestamps,
                              min_duration=args.min_audio_len,
                              max_duration=args.max_audio_len,
                              augment_config_path=args.augment_config_path)
test_dataset = CustomDataset(data_list_path=args.test_data,
                             processor=processor,
                             language=args.language,
                             timestamps=args.timestamps,
                             min_duration=args.min_audio_len,
                             max_duration=args.max_audio_len)
print(f"len train - {len(train_dataset)} test len - {len(test_dataset)}")

# padding
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

# Whisper
device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}

model = WhisperForConditionalGeneration.from_pretrained(args.base_model,
                                                        load_in_8bit=args.use_8bit,
                                                        device_map=device_map,
                                                        local_files_only=args.local_files_only)
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model = prepare_model_for_kbit_training(model)
# forward hook so that inputs require grad
model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)

print('Loading LoRA modules...')
if args.resume_from_checkpoint:
    print("Loading adapters from checkpoint.")
    model = PeftModel.from_pretrained(model, args.resume_from_checkpoint, is_trainable=True)
else:
    print(f'adding LoRA modules...')
    target_modules = ["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"]
    print(target_modules)
    if args.use_adalora:
        config = AdaLoraConfig(init_r=12, target_r=4, beta1=0.85, beta2=0.85, tinit=200, tfinal=1000, deltaT=10,
                               lora_alpha=32, lora_dropout=0.1, orth_reg_weight=0.5, target_modules=target_modules)
    else:
        config = LoraConfig(r=32, lora_alpha=64, target_modules=target_modules, lora_dropout=0.05, bias="none")
    model = get_peft_model(model, config)

output_dir = os.path.join(args.output_dir, os.path.basename(args.base_model))
training_args = \
    Seq2SeqTrainingArguments(output_dir=output_dir,                                         # Directory to save checkpoints
                             per_device_train_batch_size=args.per_device_train_batch_size,  # Training batch size
                             per_device_eval_batch_size=args.per_device_eval_batch_size,    # Eval batch size
                             gradient_accumulation_steps=args.gradient_accumulation_steps,  # Gradient accumulation steps
                             learning_rate=args.learning_rate,                              # Learning rate
                             warmup_steps=args.warmup_steps,                                # Warm-up steps
                             num_train_epochs=args.num_train_epochs,                        # Epochs
                             save_strategy="steps",
                             evaluation_strategy="steps",
                             load_best_model_at_end=True,
                             fp16=args.fp16,
                             report_to=["tensorboard"],                                     # tensorboard
                             save_steps=args.save_steps,
                             eval_steps=args.eval_steps,
                             save_total_limit=5,
                             optim='adamw_torch',
                             ddp_find_unused_parameters=False if ddp else None,
                             dataloader_num_workers=args.num_workers,
                             logging_steps=args.logging_steps,
                             remove_unused_columns=False,
                             label_names=["labels"])

if training_args.local_rank == 0 or training_args.local_rank == -1:
    print('=' * 90)
    model.print_trainable_parameters()
    print('=' * 90)

# PyTorch 2.0
if torch.__version__ >= "2" and platform.system().lower() == 'windows':
    model = torch.compile(model)

trainer = Seq2SeqTrainer(args=training_args,
                         model=model,
                         train_dataset=train_dataset,
                         eval_dataset=test_dataset,
                         data_collator=data_collator,
                         tokenizer=processor.feature_extractor,
                         callbacks=[SavePeftModelCallback])
model.config.use_cache = False
trainer._load_from_checkpoint = load_from_checkpoint

trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

trainer.save_state()
if training_args.local_rank == 0 or training_args.local_rank == -1:
    model.save_pretrained(os.path.join(output_dir, "checkpoint-final"))
infer_ct2.py
ADDED
@@ -0,0 +1,46 @@
import argparse
import functools
import os

from faster_whisper import WhisperModel

from utils.utils import print_arguments, add_arguments

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg("audio_path", type=str, default="dataset/test.wav", help="")
add_arg("model_path", type=str, default="models/whisper-tiny-finetune-ct2", help="")
add_arg("language",   type=str, default="zh",   help="")
add_arg("use_gpu",    type=bool, default=True,  help="")
add_arg("use_int8",   type=bool, default=False, help="int8")
add_arg("beam_size",  type=int, default=10,     help="")
add_arg("num_workers", type=int, default=1,     help="")
add_arg("vad_filter", type=bool, default=False, help="")
add_arg("local_files_only", type=bool, default=True, help="")
args = parser.parse_args()
print_arguments(args)

# Check that the model path exists
assert os.path.exists(args.model_path), f"{args.model_path}"
# Load the model
if args.use_gpu:
    if not args.use_int8:
        model = WhisperModel(args.model_path, device="cuda", compute_type="float16", num_workers=args.num_workers,
                             local_files_only=args.local_files_only)
    else:
        model = WhisperModel(args.model_path, device="cuda", compute_type="int8_float16", num_workers=args.num_workers,
                             local_files_only=args.local_files_only)
else:
    model = WhisperModel(args.model_path, device="cpu", compute_type="int8", num_workers=args.num_workers,
                         local_files_only=args.local_files_only)
# Warm-up transcription
_, _ = model.transcribe("dataset/test.wav", beam_size=5)


# Transcribe the target audio
segments, info = model.transcribe(args.audio_path, beam_size=args.beam_size, language=args.language,
                                  vad_filter=args.vad_filter)
for segment in segments:
    text = segment.text
    print(f"[{round(segment.start, 2)} - {round(segment.end, 2)}]:{text}\n")
infer_server.py
ADDED
@@ -0,0 +1,143 @@
import argparse
import asyncio
import functools
import json
import os
from io import BytesIO

import uvicorn
from fastapi import FastAPI, BackgroundTasks, File, Body, UploadFile, Request
from fastapi.responses import StreamingResponse
from faster_whisper import WhisperModel
from starlette.staticfiles import StaticFiles
from starlette.templating import Jinja2Templates
from zhconv import convert

from utils.data_utils import remove_punctuation
from utils.utils import add_arguments, print_arguments

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)

add_arg("host", type=str, default="0.0.0.0", help="")
add_arg("port", type=int, default=5000, help="")
add_arg("model_path", type=str, default="models/sam2ai/whisper-odia-small-finetune-int8-ct2", help="")
add_arg("use_gpu",    type=bool, default=False, help="")
add_arg("use_int8",   type=bool, default=True,  help="")
add_arg("beam_size",  type=int, default=10,     help="")
add_arg("num_workers", type=int, default=2,     help="")
add_arg("vad_filter", type=bool, default=True,  help="")
add_arg("local_files_only", type=bool, default=True, help="")
args = parser.parse_args()
print_arguments(args)

# Check that the model path exists
assert os.path.exists(args.model_path), f"{args.model_path}"
# Load the model
if args.use_gpu:
    if not args.use_int8:
        model = WhisperModel(args.model_path, device="cuda", compute_type="float16",
                             num_workers=args.num_workers, local_files_only=args.local_files_only)
    else:
        model = WhisperModel(args.model_path, device="cuda",
                             compute_type="int8_float16", num_workers=args.num_workers,
                             local_files_only=args.local_files_only)
else:
    model = WhisperModel(args.model_path, device="cpu",
                         compute_type="int8", num_workers=args.num_workers,
                         local_files_only=args.local_files_only)

# Warm-up transcription
# _, _ = model.transcribe("dataset/test.wav", beam_size=5)

app = FastAPI(title="")
app.mount('/static', StaticFiles(directory='static'), name='static')
templates = Jinja2Templates(directory="templates")
model_semaphore = None


def release_model_semaphore():
    model_semaphore.release()


def recognition(file: File, to_simple: int,
                remove_pun: int, language: str = "ory",
                task: str = "transcribe"
                ):
    segments, info = model.transcribe(file, beam_size=10, task=task, language=language, vad_filter=args.vad_filter)
    for segment in segments:
        text = segment.text
        if to_simple == 1:
            # text = convert(text, '')
            pass
        if remove_pun == 1:
            # text = remove_punctuation(text)
            pass
        ret = {"result": text, "start": round(segment.start, 2), "end": round(segment.end, 2)}
        # Stream each segment as a null-terminated JSON chunk
        yield json.dumps(ret).encode() + b"\0"


@app.post("/recognition_stream")
async def api_recognition_stream(
        to_simple: int = Body(1, description="", embed=True),
        remove_pun: int = Body(0, description="", embed=True),
        language: str = Body("ory", description="", embed=True),
        task: str = Body("transcribe", description="", embed=True),
        audio: UploadFile = File(..., description="")
):
    global model_semaphore
    if language == "None": language = None
    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(5)
    await model_semaphore.acquire()
    contents = await audio.read()
    data = BytesIO(contents)
    generator = recognition(
        file=data, to_simple=to_simple,
        remove_pun=remove_pun, language=language,
        task=task
    )
    background_tasks = BackgroundTasks()
    background_tasks.add_task(release_model_semaphore)
    return StreamingResponse(generator, background=background_tasks)


@app.post("/recognition")
async def api_recognition(
        to_simple: int = Body(1, description="", embed=True),
        remove_pun: int = Body(0, description="", embed=True),
        language: str = Body("ory", description="", embed=True),
        task: str = Body("transcribe", description="", embed=True),
        audio: UploadFile = File(..., description="")
):
    if language == "None": language = None
    contents = await audio.read()
    data = BytesIO(contents)
    generator = recognition(
        file=data, to_simple=to_simple,
        remove_pun=remove_pun, language=language,
        task=task
    )
    results = []
    for output in generator:
        output = json.loads(output[:-1].decode("utf-8"))
        results.append(output)
    ret = {"results": results, "code": 0}
    return ret


@app.get("/")
async def index(request: Request):
    return templates.TemplateResponse(
        "index.html", {"request": request, "id": id}
    )


if __name__ == '__main__':
    uvicorn.run(app, host=args.host, port=args.port)
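
For illustration, a minimal client sketch for the blocking /recognition endpoint above. It assumes the server is reachable at localhost:5001 (the port from the Dockerfile CMD); only the audio file is uploaded, so the other parameters fall back to the endpoint defaults rather than relying on how FastAPI maps Body fields alongside file uploads.

import requests

# Hypothetical local address; adjust to wherever infer_server.py is running.
url = "http://127.0.0.1:5001/recognition"

# Only the audio file is sent; to_simple, remove_pun, language and task keep
# their endpoint defaults (1, 0, "ory", "transcribe").
with open("dataset/test.mp3", "rb") as f:
    response = requests.post(url, files={"audio": ("test.mp3", f, "audio/mpeg")})

response.raise_for_status()
for seg in response.json()["results"]:
    print(f'[{seg["start"]} - {seg["end"]}]: {seg["result"]}')
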
infer_tfs.py
ADDED
@@ -0,0 +1,43 @@
import argparse
import functools

import librosa
from transformers import WhisperForConditionalGeneration, WhisperProcessor

from utils.utils import print_arguments, add_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg("audio_path", type=str, default="dataset/test.wav", help="")
add_arg("model_path", type=str, default="models/whisper-tiny-finetune", help="")
add_arg("language",   type=str, default="Oriya", help="")
add_arg("task",       type=str, default="transcribe", choices=['transcribe', 'translate'], help="")
add_arg("local_files_only", type=bool, default=True, help="")
args = parser.parse_args()
print_arguments(args)

# Whisper processor
processor = WhisperProcessor.from_pretrained(args.model_path,
                                             language=args.language,
                                             task=args.task,
                                             local_files_only=args.local_files_only)
forced_decoder_ids = processor.get_decoder_prompt_ids(language=args.language, task=args.task)

# Load the model
model = WhisperForConditionalGeneration.from_pretrained(args.model_path,
                                                        device_map="auto",
                                                        local_files_only=args.local_files_only).half()
model.eval()

# Load the audio
sample, sr = librosa.load(args.audio_path, sr=16000)
duration = sample.shape[-1] / sr
assert duration < 30, f"This program is only suitable for inferring audio shorter than 30 seconds; the current audio is {duration} seconds. Use another inference program!"

# Extract input features
input_features = processor(sample, sampling_rate=sr, return_tensors="pt", do_normalize=True).input_features.cuda().half()
# Generate
predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids, max_new_tokens=256)
# Decode to text
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
print(f"result: {transcription}")
merge_lora.py
ADDED
@@ -0,0 +1,47 @@
import argparse
import functools
import os

from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizerFast,\
    WhisperProcessor
from peft import PeftModel, PeftConfig
from utils.utils import print_arguments, add_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg("lora_model", type=str, default="output/whisper-tiny/checkpoint-best/", help="")
add_arg('output_dir', type=str, default='models/', help="")
add_arg("local_files_only", type=bool, default=False, help="")
args = parser.parse_args()
print_arguments(args)

# Check that the LoRA checkpoint exists
assert os.path.exists(args.lora_model), f"{args.lora_model}"
# LoRA config
peft_config = PeftConfig.from_pretrained(args.lora_model)
# Base Whisper model
base_model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path, device_map={"": "cpu"},
                                                             local_files_only=args.local_files_only)
# Attach the LoRA adapters
model = PeftModel.from_pretrained(base_model, args.lora_model, local_files_only=args.local_files_only)
feature_extractor = WhisperFeatureExtractor.from_pretrained(peft_config.base_model_name_or_path,
                                                            local_files_only=args.local_files_only)
tokenizer = WhisperTokenizerFast.from_pretrained(peft_config.base_model_name_or_path,
                                                 local_files_only=args.local_files_only)
processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path,
                                             local_files_only=args.local_files_only)

# Merge the adapters into the base weights
model = model.merge_and_unload()
model.train(False)

# Output directory
save_directory = os.path.join(args.output_dir, f'{os.path.basename(peft_config.base_model_name_or_path)}-finetune')
os.makedirs(save_directory, exist_ok=True)

# Save the merged model and processor
model.save_pretrained(save_directory)
feature_extractor.save_pretrained(save_directory)
tokenizer.save_pretrained(save_directory)
processor.save_pretrained(save_directory)
print(f'Model saved to: {save_directory}')
requirements.txt
ADDED
@@ -0,0 +1,21 @@
numpy>=1.23.1
soundfile>=0.12.1
librosa>=0.10.0
dataclasses>=0.6
transformers>=4.31.0
bitsandbytes>=0.41.0
soundfile>=0.12.1
datasets>=2.11.0
evaluate>=0.4.0
faster-whisper>=0.7.0
jiwer>=2.5.1
peft>=0.4.0
accelerate>=0.21.0
zhconv>=1.4.2
tqdm>=4.62.1
soundcard>=0.4.2
uvicorn>=0.21.1
fastapi>=0.95.1
starlette>=0.26.1
tensorboardX>=2.2
python-multipart
run.sh
ADDED
@@ -0,0 +1,22 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 finetune.py --base_model=openai/whisper-tiny --use_8bit=False --per_device_train_batch_size=8 --per_device_eval_batch_size=8 --gradient_accumulation_steps=1
CUDA_VISIBLE_DEVICES=0 python merge_lora.py --lora_model=output/whisper-tiny/checkpoint-final

CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 finetune.py --base_model=openai/whisper-base --use_8bit=False --per_device_train_batch_size=8 --per_device_eval_batch_size=8 --gradient_accumulation_steps=1
CUDA_VISIBLE_DEVICES=0 python merge_lora.py --lora_model=output/whisper-base/checkpoint-final

CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 finetune.py --base_model=openai/whisper-small --use_8bit=True --per_device_train_batch_size=8 --per_device_eval_batch_size=8 --gradient_accumulation_steps=1
CUDA_VISIBLE_DEVICES=0 python merge_lora.py --lora_model=output/whisper-small/checkpoint-final

CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 finetune.py --base_model=openai/whisper-medium --use_8bit=True --per_device_train_batch_size=4 --per_device_eval_batch_size=2 --gradient_accumulation_steps=2
CUDA_VISIBLE_DEVICES=0 python merge_lora.py --lora_model=output/whisper-medium/checkpoint-final

CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 finetune.py --base_model=openai/whisper-large-v2 --use_8bit=True --per_device_train_batch_size=2 --per_device_eval_batch_size=2 --gradient_accumulation_steps=4
CUDA_VISIBLE_DEVICES=0 python merge_lora.py --lora_model=output/whisper-large-v2/checkpoint-final

CUDA_VISIBLE_DEVICES=0 python evaluation.py --model_path=models/whisper-tiny-finetune
CUDA_VISIBLE_DEVICES=0 python evaluation.py --model_path=models/whisper-base-finetune
CUDA_VISIBLE_DEVICES=0 python evaluation.py --model_path=models/whisper-small-finetune
CUDA_VISIBLE_DEVICES=0 python evaluation.py --model_path=models/whisper-medium-finetune
CUDA_VISIBLE_DEVICES=0 python evaluation.py --model_path=models/whisper-large-v2-finetune
static/index.css
ADDED
@@ -0,0 +1,109 @@
* {
    box-sizing: border-box;
}

body {
    font-family: "Helvetica Neue", "Roboto", sans-serif;
    background-color: #f2f2f2;
    margin: 0;
    padding: 0;
}

#header {
    background-color: #fff;
    color: #333;
    display: flex;
    justify-content: center;
    align-items: center;
    height: 80px;
}

h1 {
    font-size: 36px;
    margin: 0;
}

#content {
    background-color: #fff;
    border-radius: 10px;
    box-shadow: 0 2px 10px rgba(0, 0, 0, 0.2);
    margin: 50px auto;
    max-width: 800px;
    padding: 20px;
}

#content div {
    display: flex;
    flex-wrap: wrap;
    justify-content: space-between;
    margin-bottom: 20px;
}

#content a {
    background-color: #fff;
    border-radius: 5px;
    box-shadow: 0 2px 10px rgba(0, 0, 0, 0.2);
    color: #333;
    padding: 10px;
    text-align: center;
    text-decoration: none;
    transition: background-color 0.2s;
    width: 20%;
}

#content a:hover {
    background-color: #f2f2f2;
}

#content img {
    cursor: pointer;
    height: 50px;
    transition: transform 0.2s;
    width: 50px;
}

#content img:hover {
    transform: scale(1.1);
}

#result {
    background-color: #fff;
    border-radius: 5px;
    box-shadow: 0 2px 10px rgba(0, 0, 0, 0.2);
    padding: 10px;
}

#result textarea {
    border: none;
    border-radius: 5px;
    font-size: 16px;
    height: 100px;
    margin-top: 10px;
    padding: 10px;
    resize: none;
    width: 100%;
}

/* #llm_result {
    background-color: #fff;
    border-radius: 5px;
    box-shadow: 0 2px 10px rgba(0, 0, 0, 0.2);
    padding: 10px;
}

#llm_result textarea {
    border: none;
    border-radius: 5px;
    font-size: 16px;
    height: 100px;
    margin-top: 10px;
    padding: 10px;
    resize: none;
    width: 100%;
} */

@media only screen and (max-width: 600px) {
    #content a {
        width: 100%;
    }
}
static/record.js
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
//兼容
|
2 |
+
window.URL = window.URL || window.webkitURL;
|
3 |
+
//获取计算机的设备:摄像头或者录音设备
|
4 |
+
navigator.getUserMedia = navigator.getUserMedia || navigator.webkitGetUserMedia || navigator.mozGetUserMedia || navigator.msGetUserMedia;
|
5 |
+
|
6 |
+
var HZRecorder = function (stream, config) {
|
7 |
+
config = config || {};
|
8 |
+
config.sampleBits = config.sampleBits || 16; //采样数位 8, 16
|
9 |
+
config.sampleRate = config.sampleRate || 16000; //采样率 16000
|
10 |
+
|
11 |
+
//创建一个音频环境对象
|
12 |
+
var audioContext = window.AudioContext || window.webkitAudioContext;
|
13 |
+
var context = new audioContext();
|
14 |
+
var audioInput = context.createMediaStreamSource(stream);
|
15 |
+
// 第二个和第三个参数指的是输入和输出都是单声道,2是双声道。
|
16 |
+
var recorder = context.createScriptProcessor(4096, 2, 2);
|
17 |
+
|
18 |
+
var audioData = {
|
19 |
+
size: 0 //录音文件长度
|
20 |
+
, buffer: [] //录音缓存
|
21 |
+
, inputSampleRate: context.sampleRate //输入采样率
|
22 |
+
, inputSampleBits: 16 //输入采样数位 8, 16
|
23 |
+
, outputSampleRate: config.sampleRate //输出采样率
|
24 |
+
, outputSampleBits: config.sampleBits //输出采样数位 8, 16
|
25 |
+
, input: function (data) {
|
26 |
+
this.buffer.push(new Float32Array(data));
|
27 |
+
this.size += data.length;
|
28 |
+
}
|
29 |
+
, compress: function () { //合并压缩
|
30 |
+
//合并
|
31 |
+
var data = new Float32Array(this.size);
|
32 |
+
var offset = 0;
|
33 |
+
for (var i = 0; i < this.buffer.length; i++) {
|
34 |
+
data.set(this.buffer[i], offset);
|
35 |
+
offset += this.buffer[i].length;
|
36 |
+
}
|
37 |
+
//压缩
|
38 |
+
var compression = parseInt(this.inputSampleRate / this.outputSampleRate);
|
39 |
+
var length = data.length / compression;
|
40 |
+
var result = new Float32Array(length);
|
41 |
+
var index = 0, j = 0;
|
42 |
+
while (index < length) {
|
43 |
+
result[index] = data[j];
|
44 |
+
j += compression;
|
45 |
+
index++;
|
46 |
+
}
|
47 |
+
return result;
|
48 |
+
}
|
49 |
+
, encodeWAV: function () {
|
50 |
+
var sampleRate = Math.min(this.inputSampleRate, this.outputSampleRate);
|
51 |
+
var sampleBits = Math.min(this.inputSampleBits, this.outputSampleBits);
|
52 |
+
var bytes = this.compress();
|
53 |
+
var dataLength = bytes.length * (sampleBits / 8);
|
54 |
+
var buffer = new ArrayBuffer(44 + dataLength);
|
55 |
+
var data = new DataView(buffer);
|
56 |
+
|
57 |
+
var channelCount = 1;//单声道
|
58 |
+
var offset = 0;
|
59 |
+
|
60 |
+
var writeString = function (str) {
|
61 |
+
for (var i = 0; i < str.length; i++) {
|
62 |
+
data.setUint8(offset + i, str.charCodeAt(i));
|
63 |
+
}
|
64 |
+
}
|
65 |
+
|
66 |
+
// 资源交换文件标识符
|
67 |
+
writeString('RIFF');
|
68 |
+
offset += 4;
|
69 |
+
// 下个地址开始到文件尾总字节数,即文件大小-8
|
70 |
+
data.setUint32(offset, 36 + dataLength, true);
|
71 |
+
offset += 4;
|
72 |
+
// WAV文件标志
|
73 |
+
writeString('WAVE');
|
74 |
+
offset += 4;
|
75 |
+
// 波形格式标志
|
76 |
+
writeString('fmt ');
|
77 |
+
offset += 4;
|
78 |
+
// 过滤字节,一般为 0x10 = 16
|
79 |
+
data.setUint32(offset, 16, true);
|
80 |
+
offset += 4;
|
81 |
+
// 格式类别 (PCM形式采样数据)
|
82 |
+
data.setUint16(offset, 1, true);
|
83 |
+
offset += 2;
|
84 |
+
// 通道数
|
85 |
+
data.setUint16(offset, channelCount, true);
|
86 |
+
offset += 2;
|
87 |
+
// Sample rate: samples per second for each channel
|
88 |
+
data.setUint32(offset, sampleRate, true);
|
89 |
+
offset += 4;
|
90 |
+
// Byte rate (average bytes per second): channels x sample rate x bits per sample / 8
|
91 |
+
data.setUint32(offset, channelCount * sampleRate * (sampleBits / 8), true);
|
92 |
+
offset += 4;
|
93 |
+
// Block align (bytes per sample frame): channels x bits per sample / 8
|
94 |
+
data.setUint16(offset, channelCount * (sampleBits / 8), true);
|
95 |
+
offset += 2;
|
96 |
+
// Bits per sample
|
97 |
+
data.setUint16(offset, sampleBits, true);
|
98 |
+
offset += 2;
|
99 |
+
// data chunk identifier
|
100 |
+
writeString('data');
|
101 |
+
offset += 4;
|
102 |
+
// Size of the sample data, i.e. total size minus the 44-byte header
|
103 |
+
data.setUint32(offset, dataLength, true);
|
104 |
+
offset += 4;
|
105 |
+
// Write the sample data
|
106 |
+
if (sampleBits === 8) {
|
107 |
+
for (var i = 0; i < bytes.length; i++, offset++) {
|
108 |
+
var s = Math.max(-1, Math.min(1, bytes[i]));
|
109 |
+
var val = s < 0 ? s * 0x8000 : s * 0x7FFF;
|
110 |
+
val = parseInt(255 / (65535 / (val + 32768)));
|
111 |
+
data.setInt8(offset, val);
|
112 |
+
}
|
113 |
+
} else {
|
114 |
+
for (var i = 0; i < bytes.length; i++, offset += 2) {
|
115 |
+
var s = Math.max(-1, Math.min(1, bytes[i]));
|
116 |
+
data.setInt16(offset, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
|
117 |
+
}
|
118 |
+
}
|
119 |
+
|
120 |
+
return new Blob([data], {type: 'audio/wav'});
|
121 |
+
}
|
122 |
+
};
|
123 |
+
|
124 |
+
// Start recording
|
125 |
+
this.start = function () {
|
126 |
+
audioInput.connect(recorder);
|
127 |
+
recorder.connect(context.destination);
|
128 |
+
}
|
129 |
+
|
130 |
+
// Stop recording
|
131 |
+
this.stop = function () {
|
132 |
+
recorder.disconnect();
|
133 |
+
}
|
134 |
+
|
135 |
+
// Get the recorded audio as a Blob
|
136 |
+
this.getBlob = function () {
|
137 |
+
this.stop();
|
138 |
+
return audioData.encodeWAV();
|
139 |
+
}
|
140 |
+
|
141 |
+
// Playback
|
142 |
+
this.play = function (audio) {
|
143 |
+
audio.src = window.URL.createObjectURL(this.getBlob());
|
144 |
+
}
|
145 |
+
// Clear the buffer
|
146 |
+
this.clear = function () {
|
147 |
+
audioData.buffer = [];
|
148 |
+
audioData.size = 0;
|
149 |
+
}
|
150 |
+
|
151 |
+
// Upload
|
152 |
+
this.upload = function (url, callback) {
|
153 |
+
var fd = new FormData();
|
154 |
+
// Field name and data to upload
|
155 |
+
fd.append("audio", this.getBlob());
|
156 |
+
var xhr = new XMLHttpRequest();
|
157 |
+
xhr.timeout = 60000
|
158 |
+
if (callback) {
|
159 |
+
xhr.upload.addEventListener("progress", function (e) {
|
160 |
+
callback('uploading', e);
|
161 |
+
}, false);
|
162 |
+
xhr.addEventListener("load", function (e) {
|
163 |
+
callback('ok', e);
|
164 |
+
}, false);
|
165 |
+
xhr.addEventListener("error", function (e) {
|
166 |
+
callback('error', e);
|
167 |
+
}, false);
|
168 |
+
xhr.addEventListener("abort", function (e) {
|
169 |
+
callback('cancel', e);
|
170 |
+
}, false);
|
171 |
+
}
|
172 |
+
xhr.open("POST", url);
|
173 |
+
xhr.send(fd);
|
174 |
+
}
|
175 |
+
|
176 |
+
// Audio capture callback
|
177 |
+
recorder.onaudioprocess = function (e) {
|
178 |
+
audioData.input(e.inputBuffer.getChannelData(0));
|
179 |
+
//record(e.inputBuffer.getChannelData(0));
|
180 |
+
}
|
181 |
+
|
182 |
+
};
|
183 |
+
// Throw an error
|
184 |
+
HZRecorder.throwError = function (message) {
|
185 |
+
alert(message);
|
186 |
+
throw new function () {
|
187 |
+
this.toString = function () {
|
188 |
+
return message;
|
189 |
+
}
|
190 |
+
}
|
191 |
+
}
|
192 |
+
// Whether recording is supported
|
193 |
+
HZRecorder.canRecording = (navigator.getUserMedia != null);
|
194 |
+
// Get a recorder instance
|
195 |
+
HZRecorder.get = function (callback, config) {
|
196 |
+
if (callback) {
|
197 |
+
if (navigator.getUserMedia) {
|
198 |
+
navigator.getUserMedia(
|
199 |
+
{audio: true} // audio only
|
200 |
+
, function (stream) {
|
201 |
+
var rec = new HZRecorder(stream, config);
|
202 |
+
callback(rec);
|
203 |
+
}
|
204 |
+
, function (error) {
|
205 |
+
switch (error.code || error.name) {
|
206 |
+
case 'PERMISSION_DENIED':
|
207 |
+
case 'PermissionDeniedError':
|
208 |
+
HZRecorder.throwError('The user denied permission.');
|
209 |
+
break;
|
210 |
+
case 'NOT_SUPPORTED_ERROR':
|
211 |
+
case 'NotSupportedError':
|
212 |
+
HZRecorder.throwError('The browser does not support the hardware device.');
|
213 |
+
break;
|
214 |
+
case 'MANDATORY_UNSATISFIED_ERROR':
|
215 |
+
case 'MandatoryUnsatisfiedError':
|
216 |
+
HZRecorder.throwError('The specified hardware device could not be found.');
|
217 |
+
break;
|
218 |
+
default:
|
219 |
+
HZRecorder.throwError('Unable to open the microphone. Error: ' + (error.code || error.name));
|
220 |
+
break;
|
221 |
+
}
|
222 |
+
});
|
223 |
+
} else {
|
224 |
+
window.alert('Recording is only available over HTTPS or on a localhost address!')
|
225 |
+
HZRecorder.throwError('The current browser does not support recording.');
|
226 |
+
return;
|
227 |
+
}
|
228 |
+
}
|
229 |
+
};
|
static/record.png
ADDED
static/recording.gif
ADDED
templates/index.html
ADDED
@@ -0,0 +1,167 @@
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
<head>
|
4 |
+
<meta charset="UTF-8">
|
5 |
+
<title>OdiaGenAI Speech Recognition</title>
|
6 |
+
<script type="text/javascript" src="/static/record.js"></script>
|
7 |
+
<link href="/static/index.css" rel="stylesheet" type="text/css"/>
|
8 |
+
</head>
|
9 |
+
<body>
|
10 |
+
<div id="header">
|
11 |
+
<h1>OdiaGenAI Speech Recognition</h1>
|
12 |
+
</div>
|
13 |
+
<div id="content">
|
14 |
+
<div>
|
15 |
+
<a id="upload" onclick="uploadAudioFile()" class="file">select audio file</a>
|
16 |
+
<a id="play_btn" onclick="uploadRecordAudio()" class="file">predict audio file</a>
|
17 |
+
<audio controls autoplay></audio>
|
18 |
+
<img id="record_btn" onclick="record()" src="/static/record.png" alt="record"/>
|
19 |
+
</div>
|
20 |
+
<div id="result">
|
21 |
+
<label for="result_p"></label><textarea id="result_p"></textarea>
|
22 |
+
</div>
|
23 |
+
<!-- <div id="llm_result">
|
24 |
+
<a id="llm_predict" onclick="uploadAudioFile()" class="file">generate text</a>
|
25 |
+
<label for="result_llm"></label><textarea id="result_llm"></textarea>
|
26 |
+
</div> -->
|
27 |
+
</div>
|
28 |
+
<script>
|
29 |
+
let is_recording = false;
|
30 |
+
let is_playing = false;
|
31 |
+
let host = location.origin;
|
32 |
+
let recorder;
|
33 |
+
let audio = document.querySelector('audio');
|
34 |
+
let textarea = document.getElementById('result_p')
|
35 |
+
|
36 |
+
|
37 |
+
function record() {
|
38 |
+
if (is_recording) {
|
39 |
+
is_recording = false;
|
40 |
+
stopRecording()
|
41 |
+
document.getElementById('record_btn').src = '/static/record.png'
|
42 |
+
startPlay();
|
43 |
+
stopPlay();
|
44 |
+
} else {
|
45 |
+
is_recording = true;
|
46 |
+
startRecording()
|
47 |
+
document.getElementById('record_btn').src = '/static/recording.gif'
|
48 |
+
}
|
49 |
+
}
|
50 |
+
|
51 |
+
function play() {
|
52 |
+
if (is_playing) {
|
53 |
+
is_playing = false;
|
54 |
+
stopPlay()
|
55 |
+
document.getElementById('play_btn').innerText = 'play audio'
|
56 |
+
} else {
|
57 |
+
is_playing = true;
|
58 |
+
startPlay()
|
59 |
+
document.getElementById('play_btn').innerText = 'Stop play'
|
60 |
+
}
|
61 |
+
}
|
62 |
+
|
63 |
+
function startRecording() {
|
64 |
+
HZRecorder.get(function (rec) {
|
65 |
+
recorder = rec;
|
66 |
+
recorder.start();
|
67 |
+
});
|
68 |
+
}
|
69 |
+
|
70 |
+
function stopRecording() {
|
71 |
+
recorder.stop();
|
72 |
+
}
|
73 |
+
|
74 |
+
function startPlay() {
|
75 |
+
recorder.play(audio);
|
76 |
+
}
|
77 |
+
|
78 |
+
function stopPlay() {
|
79 |
+
audio.pause();
|
80 |
+
}
|
81 |
+
|
82 |
+
function cancelAudio() {
|
83 |
+
recorder.stop();
|
84 |
+
recorder.clear();
|
85 |
+
}
|
86 |
+
|
87 |
+
function uploadRecordAudio() {
|
88 |
+
recorder.upload(location.origin + "/recognition", function (state, e) {
|
89 |
+
switch (state) {
|
90 |
+
case 'uploading':
|
91 |
+
const percentComplete = Math.round(e.loaded * 100 / e.total) + '%';
|
92 |
+
console.log(percentComplete);
|
93 |
+
break;
|
94 |
+
case 'ok':
|
95 |
+
console.log(e.target.responseText)
|
96 |
+
document.getElementById('result_p').innerHTML = e.target.responseText
|
97 |
+
break;
|
98 |
+
case 'error':
|
99 |
+
alert("upload failed");
|
100 |
+
break;
|
101 |
+
case 'cancel':
|
102 |
+
alert("upload canceled");
|
103 |
+
break;
|
104 |
+
}
|
105 |
+
});
|
106 |
+
}
|
107 |
+
|
108 |
+
//
|
109 |
+
function uploadAudioFile() {
|
110 |
+
const input = document.createElement("input");
|
111 |
+
input.type = "file";
|
112 |
+
input.accept = "audio/*,video/*";
|
113 |
+
input.click();
|
114 |
+
input.onchange = function () {
|
115 |
+
const file = input.files[0];
|
116 |
+
console.log(file)
|
117 |
+
audio.src = window.URL.createObjectURL(file);
|
118 |
+
stopPlay();
|
119 |
+
upload_file(host + "/recognition", file, function (state, e) {
|
120 |
+
switch (state) {
|
121 |
+
case 'uploading':
|
122 |
+
const percentComplete = Math.round(e.loaded * 100 / e.total) + '%';
|
123 |
+
console.log(percentComplete);
|
124 |
+
break;
|
125 |
+
case 'ok':
|
126 |
+
console.log(e.target.responseText)
|
127 |
+
textarea.innerText = e.target.responseText
|
128 |
+
break;
|
129 |
+
case 'error':
|
130 |
+
alert("upload failed");
|
131 |
+
break;
|
132 |
+
case 'cancel':
|
133 |
+
alert("upload canceled");
|
134 |
+
break;
|
135 |
+
}
|
136 |
+
});
|
137 |
+
}
|
138 |
+
}
|
139 |
+
|
140 |
+
//
|
141 |
+
upload_file = function (url, file, callback) {
|
142 |
+
const fd = new FormData();
|
143 |
+
//
|
144 |
+
fd.append("audio", file);
|
145 |
+
const xhr = new XMLHttpRequest();
|
146 |
+
xhr.timeout = 60000
|
147 |
+
if (callback) {
|
148 |
+
xhr.upload.addEventListener("progress", function (e) {
|
149 |
+
callback('uploading', e);
|
150 |
+
}, false);
|
151 |
+
xhr.addEventListener("load", function (e) {
|
152 |
+
callback('ok', e);
|
153 |
+
}, false);
|
154 |
+
xhr.addEventListener("error", function (e) {
|
155 |
+
callback('error', e);
|
156 |
+
}, false);
|
157 |
+
xhr.addEventListener("abort", function (e) {
|
158 |
+
callback('cancel', e);
|
159 |
+
}, false);
|
160 |
+
}
|
161 |
+
xhr.open("POST", url);
|
162 |
+
xhr.send(fd);
|
163 |
+
}
|
164 |
+
</script>
|
165 |
+
|
166 |
+
</body>
|
167 |
+
</html>
|
utils/__init__.py
ADDED
File without changes
|
utils/binary.py
ADDED
@@ -0,0 +1,72 @@
1 |
+
import json
|
2 |
+
import mmap
|
3 |
+
|
4 |
+
import struct
|
5 |
+
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
|
9 |
+
class DatasetWriter(object):
|
10 |
+
def __init__(self, prefix):
|
11 |
+
#
|
12 |
+
self.data_file = open(prefix + '.data', 'wb')
|
13 |
+
self.header_file = open(prefix + '.header', 'wb')
|
14 |
+
self.data_sum = 0
|
15 |
+
self.offset = 0
|
16 |
+
self.header = ''
|
17 |
+
|
18 |
+
def add_data(self, data):
|
19 |
+
key = str(self.data_sum)
|
20 |
+
data = bytes(data, encoding="utf8")
|
21 |
+
#
|
22 |
+
self.data_file.write(struct.pack('I', len(key)))
|
23 |
+
self.data_file.write(key.encode('ascii'))
|
24 |
+
self.data_file.write(struct.pack('I', len(data)))
|
25 |
+
self.data_file.write(data)
|
26 |
+
#
|
27 |
+
self.offset += 4 + len(key) + 4
|
28 |
+
self.header = key + '\t' + str(self.offset) + '\t' + str(len(data)) + '\n'
|
29 |
+
self.header_file.write(self.header.encode('ascii'))
|
30 |
+
self.offset += len(data)
|
31 |
+
self.data_sum += 1
|
32 |
+
|
33 |
+
def close(self):
|
34 |
+
self.data_file.close()
|
35 |
+
self.header_file.close()
|
36 |
+
|
37 |
+
|
38 |
+
class DatasetReader(object):
|
39 |
+
def __init__(self, data_header_path, min_duration=0, max_duration=30):
|
40 |
+
self.keys = []
|
41 |
+
self.offset_dict = {}
|
42 |
+
self.fp = open(data_header_path.replace('.header', '.data'), 'rb')
|
43 |
+
self.m = mmap.mmap(self.fp.fileno(), 0, access=mmap.ACCESS_READ)
|
44 |
+
for line in tqdm(open(data_header_path, 'rb'), desc='Reading data list'):
|
45 |
+
key, val_pos, val_len = line.split('\t'.encode('ascii'))
|
46 |
+
data = self.m[int(val_pos):int(val_pos) + int(val_len)]
|
47 |
+
data = str(data, encoding="utf-8")
|
48 |
+
data = json.loads(data)
|
49 |
+
#
|
50 |
+
if data["duration"] < min_duration:
|
51 |
+
continue
|
52 |
+
if max_duration != -1 and data["duration"] > max_duration:
|
53 |
+
continue
|
54 |
+
self.keys.append(key)
|
55 |
+
self.offset_dict[key] = (int(val_pos), int(val_len))
|
56 |
+
|
57 |
+
#
|
58 |
+
def get_data(self, key):
|
59 |
+
p = self.offset_dict.get(key, None)
|
60 |
+
if p is None:
|
61 |
+
return None
|
62 |
+
val_pos, val_len = p
|
63 |
+
data = self.m[val_pos:val_pos + val_len]
|
64 |
+
data = str(data, encoding="utf-8")
|
65 |
+
return json.loads(data)
|
66 |
+
|
67 |
+
#
|
68 |
+
def get_keys(self):
|
69 |
+
return self.keys
|
70 |
+
|
71 |
+
def __len__(self):
|
72 |
+
return len(self.keys)
|
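For orientation, a minimal usage sketch of the two classes above (not part of the diff; the file prefix and the record fields shown here are assumptions for illustration):

import json
from utils.binary import DatasetWriter, DatasetReader

# Write one JSON record per utterance into prefix.data / prefix.header.
writer = DatasetWriter('dataset/train')
writer.add_data(json.dumps({"audio": {"path": "dataset/test.mp3"},
                            "sentence": "example transcript",
                            "duration": 3.2}, ensure_ascii=False))
writer.close()

# Read the records back; the reader filters by the "duration" field exactly as above.
reader = DatasetReader('dataset/train.header', min_duration=0.5, max_duration=30)
first_key = reader.get_keys()[0]
print(reader.get_data(first_key)["sentence"])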
utils/callback.py
ADDED
@@ -0,0 +1,37 @@
1 |
+
import os
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
|
5 |
+
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
|
6 |
+
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
7 |
+
|
8 |
+
|
9 |
+
#
|
10 |
+
class SavePeftModelCallback(TrainerCallback):
|
11 |
+
def on_save(self,
|
12 |
+
args: TrainingArguments,
|
13 |
+
state: TrainerState,
|
14 |
+
control: TrainerControl,
|
15 |
+
**kwargs, ):
|
16 |
+
if args.local_rank == 0 or args.local_rank == -1:
|
17 |
+
#
|
18 |
+
checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
|
19 |
+
peft_model_dir = os.path.join(checkpoint_folder, "adapter_model")
|
20 |
+
kwargs["model"].save_pretrained(peft_model_dir)
|
21 |
+
peft_config_path = os.path.join(checkpoint_folder, "adapter_model/adapter_config.json")
|
22 |
+
peft_model_path = os.path.join(checkpoint_folder, "adapter_model/adapter_model.bin")
|
23 |
+
if not os.path.exists(peft_config_path):
|
24 |
+
os.remove(peft_config_path)
|
25 |
+
if not os.path.exists(peft_model_path):
|
26 |
+
os.remove(peft_model_path)
|
27 |
+
if os.path.exists(peft_model_dir):
|
28 |
+
shutil.rmtree(peft_model_dir)
|
29 |
+
#
|
30 |
+
best_checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-best")
|
31 |
+
#
|
32 |
+
if os.path.exists(state.best_model_checkpoint):
|
33 |
+
if os.path.exists(best_checkpoint_folder):
|
34 |
+
shutil.rmtree(best_checkpoint_folder)
|
35 |
+
shutil.copytree(state.best_model_checkpoint, best_checkpoint_folder)
|
36 |
+
print(f"Best checkpoint: {state.best_model_checkpoint}, best metric: {state.best_metric}")
|
37 |
+
return control
|
utils/data_utils.py
ADDED
@@ -0,0 +1,65 @@
1 |
+
import re
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import Any, List, Dict, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from zhconv import convert
|
7 |
+
|
8 |
+
|
9 |
+
# Remove punctuation marks
|
10 |
+
def remove_punctuation(text: Union[str, List[str]]):
|
11 |
+
punctuation = '!,.;:?、!,。;:?'
|
12 |
+
if isinstance(text, str):
|
13 |
+
text = re.sub(r'[{}]+'.format(punctuation), '', text).strip()
|
14 |
+
return text
|
15 |
+
elif isinstance(text, list):
|
16 |
+
result_text = []
|
17 |
+
for t in text:
|
18 |
+
t = re.sub(r'[{}]+'.format(punctuation), '', t).strip()
|
19 |
+
result_text.append(t)
|
20 |
+
return result_text
|
21 |
+
else:
|
22 |
+
raise Exception(f'Unsupported type: {type(text)}')
|
23 |
+
|
24 |
+
|
25 |
+
# Convert traditional Chinese to simplified Chinese
|
26 |
+
def to_simple(text: Union[str, List[str]]):
|
27 |
+
if isinstance(text, str):
|
28 |
+
text = convert(text, 'zh-cn')
|
29 |
+
return text
|
30 |
+
elif isinstance(text, list):
|
31 |
+
result_text = []
|
32 |
+
for t in text:
|
33 |
+
t = convert(t, 'zh-cn')
|
34 |
+
result_text.append(t)
|
35 |
+
return result_text
|
36 |
+
else:
|
37 |
+
raise Exception(f'Unsupported type: {type(text)}')
|
38 |
+
|
39 |
+
|
40 |
+
@dataclass
|
41 |
+
class DataCollatorSpeechSeq2SeqWithPadding:
|
42 |
+
processor: Any
|
43 |
+
|
44 |
+
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
45 |
+
# split inputs and labels since they have to be of different lengths and need different padding methods
|
46 |
+
# first treat the audio inputs by simply returning torch tensors
|
47 |
+
input_features = [{"input_features": feature["input_features"][0]} for feature in features]
|
48 |
+
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
|
49 |
+
|
50 |
+
# get the tokenized label sequences
|
51 |
+
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
52 |
+
# pad the labels to max length
|
53 |
+
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
|
54 |
+
|
55 |
+
# replace padding with -100 to ignore loss correctly
|
56 |
+
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
|
57 |
+
|
58 |
+
# if bos token is appended in previous tokenization step,
|
59 |
+
# cut bos token here as it's append later anyways
|
60 |
+
if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
|
61 |
+
labels = labels[:, 1:]
|
62 |
+
|
63 |
+
batch["labels"] = labels
|
64 |
+
|
65 |
+
return batch
|
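A short sketch of how the collator defined above is typically plugged in; the checkpoint name is an assumption, not something this diff pins down:

from transformers import WhisperProcessor
from utils.data_utils import DataCollatorSpeechSeq2SeqWithPadding

processor = WhisperProcessor.from_pretrained("openai/whisper-small")
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
# Each dataset item is expected to carry "input_features" and "labels",
# which is what CustomDataset in utils/reader.py returns; the collator pads
# both and masks label padding with -100 so it is ignored by the loss.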
utils/model_utils.py
ADDED
@@ -0,0 +1,20 @@
1 |
+
import bitsandbytes as bnb
|
2 |
+
import torch
|
3 |
+
from transformers.trainer_pt_utils import LabelSmoother
|
4 |
+
|
5 |
+
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
|
6 |
+
|
7 |
+
|
8 |
+
def find_all_linear_names(use_8bit, model):
|
9 |
+
cls = bnb.nn.Linear8bitLt if use_8bit else torch.nn.Linear
|
10 |
+
lora_module_names = set()
|
11 |
+
for name, module in model.named_modules():
|
12 |
+
if isinstance(module, cls):
|
13 |
+
names = name.split('.')
|
14 |
+
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
15 |
+
target_modules = list(lora_module_names)
|
16 |
+
return target_modules
|
17 |
+
|
18 |
+
|
19 |
+
def load_from_checkpoint(resume_from_checkpoint, model=None):
|
20 |
+
pass
|
utils/pun_predictor.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import paddle.inference as paddle_infer
|
7 |
+
from paddlenlp.transformers import ErnieTokenizer
|
8 |
+
|
9 |
+
|
10 |
+
__all__ = ['PunctuationExecutor']
|
11 |
+
|
12 |
+
|
13 |
+
class PunctuationExecutor:
|
14 |
+
def __init__(self, model_dir, use_gpu=True, gpu_mem=500, num_threads=4):
|
15 |
+
# config
|
16 |
+
model_path = os.path.join(model_dir, 'model.pdmodel')
|
17 |
+
params_path = os.path.join(model_dir, 'model.pdiparams')
|
18 |
+
if not os.path.exists(model_path) or not os.path.exists(params_path):
|
19 |
+
raise Exception("Model file not found: {} or {}".format(model_path, params_path))
|
20 |
+
self.config = paddle_infer.Config(model_path, params_path)
|
21 |
+
#
|
22 |
+
pretrained_token = 'ernie-1.0'
|
23 |
+
if os.path.exists(os.path.join(model_dir, 'info.json')):
|
24 |
+
with open(os.path.join(model_dir, 'info.json'), 'r', encoding='utf-8') as f:
|
25 |
+
data = json.load(f)
|
26 |
+
pretrained_token = data['pretrained_token']
|
27 |
+
|
28 |
+
if use_gpu:
|
29 |
+
self.config.enable_use_gpu(gpu_mem, 0)
|
30 |
+
else:
|
31 |
+
self.config.disable_gpu()
|
32 |
+
self.config.set_cpu_math_library_num_threads(num_threads)
|
33 |
+
# enable memory optim
|
34 |
+
self.config.enable_memory_optim()
|
35 |
+
self.config.disable_glog_info()
|
36 |
+
|
37 |
+
# config predictor
|
38 |
+
self.predictor = paddle_infer.create_predictor(self.config)
|
39 |
+
|
40 |
+
#
|
41 |
+
self.input_ids_handle = self.predictor.get_input_handle('input_ids')
|
42 |
+
self.token_type_ids_handle = self.predictor.get_input_handle('token_type_ids')
|
43 |
+
|
44 |
+
#
|
45 |
+
self.output_names = self.predictor.get_output_names()
|
46 |
+
|
47 |
+
self._punc_list = []
|
48 |
+
if not os.path.exists(os.path.join(model_dir, 'vocab.txt')):
|
49 |
+
raise Exception("Vocabulary file not found: {}".format(os.path.join(model_dir, 'vocab.txt')))
|
50 |
+
with open(os.path.join(model_dir, 'vocab.txt'), 'r', encoding='utf-8') as f:
|
51 |
+
for line in f:
|
52 |
+
self._punc_list.append(line.strip())
|
53 |
+
|
54 |
+
self.tokenizer = ErnieTokenizer.from_pretrained(pretrained_token)
|
55 |
+
|
56 |
+
#
|
57 |
+
self('warm up')  # warm-up inference; the original warm-up text was stripped, any non-empty string works here
|
58 |
+
|
59 |
+
def _clean_text(self, text):
|
60 |
+
text = text.lower()
|
61 |
+
text = re.sub('[^A-Za-z0-9\u4e00-\u9fa5]', '', text)
|
62 |
+
text = re.sub(f'[{"".join([p for p in self._punc_list][1:])}]', '', text)
|
63 |
+
return text
|
64 |
+
|
65 |
+
#
|
66 |
+
def preprocess(self, text: str):
|
67 |
+
clean_text = self._clean_text(text)
|
68 |
+
if len(clean_text) == 0: return None
|
69 |
+
tokenized_input = self.tokenizer(list(clean_text), return_length=True, is_split_into_words=True)
|
70 |
+
input_ids = tokenized_input['input_ids']
|
71 |
+
seg_ids = tokenized_input['token_type_ids']
|
72 |
+
seq_len = tokenized_input['seq_len']
|
73 |
+
return input_ids, seg_ids, seq_len
|
74 |
+
|
75 |
+
def infer(self, input_ids: list, seg_ids: list):
|
76 |
+
#
|
77 |
+
self.input_ids_handle.reshape([1, len(input_ids)])
|
78 |
+
self.token_type_ids_handle.reshape([1, len(seg_ids)])
|
79 |
+
self.input_ids_handle.copy_from_cpu(np.array([input_ids]).astype('int64'))
|
80 |
+
self.token_type_ids_handle.copy_from_cpu(np.array([seg_ids]).astype('int64'))
|
81 |
+
|
82 |
+
# predictor
|
83 |
+
self.predictor.run()
|
84 |
+
|
85 |
+
#
|
86 |
+
output_handle = self.predictor.get_output_handle(self.output_names[0])
|
87 |
+
output_data = output_handle.copy_to_cpu()
|
88 |
+
return output_data
|
89 |
+
|
90 |
+
#
|
91 |
+
def postprocess(self, input_ids, seq_len, preds):
|
92 |
+
tokens = self.tokenizer.convert_ids_to_tokens(input_ids[1:seq_len - 1])
|
93 |
+
labels = preds[1:seq_len - 1].tolist()
|
94 |
+
assert len(tokens) == len(labels)
|
95 |
+
|
96 |
+
text = ''
|
97 |
+
for t, l in zip(tokens, labels):
|
98 |
+
text += t
|
99 |
+
if l != 0:
|
100 |
+
text += self._punc_list[l]
|
101 |
+
return text
|
102 |
+
|
103 |
+
def __call__(self, text: str) -> str:
|
104 |
+
#
|
105 |
+
input_ids, seg_ids, seq_len = self.preprocess(text)
|
106 |
+
preds = self.infer(input_ids=input_ids, seg_ids=seg_ids)
|
107 |
+
if len(preds.shape) == 2:
|
108 |
+
preds = preds[0]
|
109 |
+
text = self.postprocess(input_ids, seq_len, preds)
|
110 |
+
return text
|
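A hedged usage sketch for the executor above; the exported-model directory is an assumption, and it must contain the model.pdmodel, model.pdiparams and vocab.txt files checked in __init__:

from utils.pun_predictor import PunctuationExecutor

# Load the exported ERNIE punctuation model once, then call it per transcript.
pun_executor = PunctuationExecutor(model_dir='models/pun_models', use_gpu=False)
print(pun_executor('hello world how are you'))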
utils/reader.py
ADDED
@@ -0,0 +1,289 @@
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import sys
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
import librosa
|
8 |
+
import numpy as np
|
9 |
+
import soundfile
|
10 |
+
from torch.utils.data import Dataset
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
from utils.binary import DatasetReader
|
14 |
+
|
15 |
+
|
16 |
+
class CustomDataset(Dataset):
|
17 |
+
def __init__(self,
|
18 |
+
data_list_path,
|
19 |
+
processor,
|
20 |
+
mono=True,
|
21 |
+
language=None,
|
22 |
+
timestamps=False,
|
23 |
+
sample_rate=16000,
|
24 |
+
min_duration=0.5,
|
25 |
+
max_duration=30,
|
26 |
+
augment_config_path=None):
|
27 |
+
"""
|
28 |
+
Args:
|
29 |
+
data_list_path:
|
30 |
+
processor: Whisper
|
31 |
+
mono: True
|
32 |
+
language:
|
33 |
+
timestamps:
|
34 |
+
sample_rate: 16000
|
35 |
+
min_duration: 0.5s
|
36 |
+
max_duration: 30s
|
37 |
+
augment_config_path:
|
38 |
+
"""
|
39 |
+
super(CustomDataset, self).__init__()
|
40 |
+
assert min_duration >= 0.5, f"min_duration must be at least 0.5 s, got {min_duration}"
|
41 |
+
assert max_duration <= 30, f"max_duration must be at most 30 s, got {max_duration}"
|
42 |
+
self.data_list_path = data_list_path
|
43 |
+
self.processor = processor
|
44 |
+
self.data_list_path = data_list_path
|
45 |
+
self.sample_rate = sample_rate
|
46 |
+
self.mono = mono
|
47 |
+
self.language = language
|
48 |
+
self.timestamps = timestamps
|
49 |
+
self.min_duration = min_duration
|
50 |
+
self.max_duration = max_duration
|
51 |
+
self.vocab = self.processor.tokenizer.get_vocab()
|
52 |
+
self.timestamp_begin = self.vocab['<|notimestamps|>'] + 1
|
53 |
+
self.startoftranscript = self.vocab['<|startoftranscript|>']
|
54 |
+
self.endoftext = self.vocab['<|endoftext|>']
|
55 |
+
self.nocaptions = self.vocab['<|nocaptions|>']
|
56 |
+
self.data_list: List[dict] = []
|
57 |
+
#
|
58 |
+
self._load_data_list()
|
59 |
+
#
|
60 |
+
self.augment_configs = None
|
61 |
+
self.noises_path = None
|
62 |
+
self.speed_rates = None
|
63 |
+
if augment_config_path:
|
64 |
+
with open(augment_config_path, 'r', encoding='utf-8') as f:
|
65 |
+
self.augment_configs = json.load(f)
|
66 |
+
|
67 |
+
#
|
68 |
+
def _load_data_list(self):
|
69 |
+
if self.data_list_path.endswith(".header"):
|
70 |
+
#
|
71 |
+
self.dataset_reader = DatasetReader(data_header_path=self.data_list_path,
|
72 |
+
min_duration=self.min_duration,
|
73 |
+
max_duration=self.max_duration)
|
74 |
+
self.data_list = self.dataset_reader.get_keys()
|
75 |
+
else:
|
76 |
+
#
|
77 |
+
with open(self.data_list_path, 'r', encoding='utf-8') as f:
|
78 |
+
lines = f.readlines()
|
79 |
+
self.data_list = []
|
80 |
+
for line in tqdm(lines, desc='Reading data list'):
|
81 |
+
if isinstance(line, str):
|
82 |
+
line = json.loads(line)
|
83 |
+
if not isinstance(line, dict): continue
|
84 |
+
#
|
85 |
+
if line["duration"] < self.min_duration:
|
86 |
+
continue
|
87 |
+
if self.max_duration != -1 and line["duration"] > self.max_duration:
|
88 |
+
continue
|
89 |
+
self.data_list.append(dict(line))
|
90 |
+
|
91 |
+
#
|
92 |
+
def _get_list_data(self, idx):
|
93 |
+
if self.data_list_path.endswith(".header"):
|
94 |
+
data_list = self.dataset_reader.get_data(self.data_list[idx])
|
95 |
+
else:
|
96 |
+
data_list = self.data_list[idx]
|
97 |
+
#
|
98 |
+
audio_file = data_list["audio"]['path']
|
99 |
+
transcript = data_list["sentences"] if self.timestamps else data_list["sentence"]
|
100 |
+
language = data_list["language"] if 'language' in data_list.keys() else None
|
101 |
+
if 'start_time' not in data_list["audio"].keys():
|
102 |
+
sample, sample_rate = soundfile.read(audio_file, dtype='float32')
|
103 |
+
else:
|
104 |
+
start_time, end_time = data_list["audio"]["start_time"], data_list["audio"]["end_time"]
|
105 |
+
#
|
106 |
+
sample, sample_rate = self.slice_from_file(audio_file, start=start_time, end=end_time)
|
107 |
+
sample = sample.T
|
108 |
+
#
|
109 |
+
if self.mono:
|
110 |
+
sample = librosa.to_mono(sample)
|
111 |
+
#
|
112 |
+
if self.augment_configs:
|
113 |
+
sample, sample_rate = self.augment(sample, sample_rate)
|
114 |
+
#
|
115 |
+
if self.sample_rate != sample_rate:
|
116 |
+
sample = self.resample(sample, orig_sr=sample_rate, target_sr=self.sample_rate)
|
117 |
+
return sample, sample_rate, transcript, language
|
118 |
+
|
119 |
+
def _load_timestamps_transcript(self, transcript: List[dict]):
|
120 |
+
assert isinstance(transcript, list), f"transcript should be a list, got {type(transcript)}"
|
121 |
+
data = dict()
|
122 |
+
labels = self.processor.tokenizer.prefix_tokens[:3]
|
123 |
+
for t in transcript:
|
124 |
+
#
|
125 |
+
start = t['start'] if round(t['start'] * 100) % 2 == 0 else t['start'] + 0.01
|
126 |
+
start = self.timestamp_begin + round(start * 100) // 2
|
127 |
+
end = t['end'] if round(t['end'] * 100) % 2 == 0 else t['end'] - 0.01
|
128 |
+
end = self.timestamp_begin + round(end * 100) // 2
|
129 |
+
label = self.processor(text=t['text']).input_ids[4:-1]
|
130 |
+
labels.extend([start])
|
131 |
+
labels.extend(label)
|
132 |
+
labels.extend([end])
|
133 |
+
data['labels'] = labels + [self.endoftext]
|
134 |
+
return data
|
135 |
+
|
136 |
+
def __getitem__(self, idx):
|
137 |
+
try:
|
138 |
+
#
|
139 |
+
sample, sample_rate, transcript, language = self._get_list_data(idx=idx)
|
140 |
+
#
|
141 |
+
self.processor.tokenizer.set_prefix_tokens(language=language if language is not None else self.language)
|
142 |
+
if len(transcript) > 0:
|
143 |
+
#
|
144 |
+
if self.timestamps:
|
145 |
+
data = self._load_timestamps_transcript(transcript=transcript)
|
146 |
+
#
|
147 |
+
data["input_features"] = self.processor(audio=sample, sampling_rate=self.sample_rate).input_features
|
148 |
+
else:
|
149 |
+
#
|
150 |
+
data = self.processor(audio=sample, sampling_rate=self.sample_rate, text=transcript)
|
151 |
+
else:
|
152 |
+
#
|
153 |
+
data = self.processor(audio=sample, sampling_rate=self.sample_rate)
|
154 |
+
data['labels'] = [self.startoftranscript, self.nocaptions, self.endoftext]
|
155 |
+
return data
|
156 |
+
except Exception as e:
|
157 |
+
print(f'idx:{idx} error - {e}', file=sys.stderr)
|
158 |
+
return self.__getitem__(random.randint(0, self.__len__() - 1))
|
159 |
+
|
160 |
+
def __len__(self):
|
161 |
+
return len(self.data_list)
|
162 |
+
|
163 |
+
#
|
164 |
+
@staticmethod
|
165 |
+
def slice_from_file(file, start, end):
|
166 |
+
sndfile = soundfile.SoundFile(file)
|
167 |
+
sample_rate = sndfile.samplerate
|
168 |
+
duration = round(float(len(sndfile)) / sample_rate, 3)
|
169 |
+
start = round(start, 3)
|
170 |
+
end = round(end, 3)
|
171 |
+
#
|
172 |
+
if start < 0.0: start += duration
|
173 |
+
if end < 0.0: end += duration
|
174 |
+
#
|
175 |
+
if start < 0.0: start = 0.0
|
176 |
+
if end > duration: end = duration
|
177 |
+
if end < 0.0:
|
178 |
+
raise ValueError("Slice end position (%f s) is out of bounds" % end)
|
179 |
+
if start > end:
|
180 |
+
raise ValueError("Slice start (%f s) is later than slice end (%f s)" % (start, end))
|
181 |
+
start_frame = int(start * sample_rate)
|
182 |
+
end_frame = int(end * sample_rate)
|
183 |
+
sndfile.seek(start_frame)
|
184 |
+
sample = sndfile.read(frames=end_frame - start_frame, dtype='float32')
|
185 |
+
return sample, sample_rate
|
186 |
+
|
187 |
+
#
|
188 |
+
def augment(self, sample, sample_rate):
|
189 |
+
for config in self.augment_configs:
|
190 |
+
if config['type'] == 'speed' and random.random() < config['prob']:
|
191 |
+
if self.speed_rates is None:
|
192 |
+
min_speed_rate, max_speed_rate, num_rates = config['params']['min_speed_rate'], \
|
193 |
+
config['params']['max_speed_rate'], config['params']['num_rates']
|
194 |
+
self.speed_rates = np.linspace(min_speed_rate, max_speed_rate, num_rates, endpoint=True)
|
195 |
+
rate = random.choice(self.speed_rates)
|
196 |
+
sample = self.change_speed(sample, speed_rate=rate)
|
197 |
+
if config['type'] == 'shift' and random.random() < config['prob']:
|
198 |
+
min_shift_ms, max_shift_ms = config['params']['min_shift_ms'], config['params']['max_shift_ms']
|
199 |
+
shift_ms = random.randint(min_shift_ms, max_shift_ms)
|
200 |
+
sample = self.shift(sample, sample_rate, shift_ms=shift_ms)
|
201 |
+
if config['type'] == 'volume' and random.random() < config['prob']:
|
202 |
+
min_gain_dBFS, max_gain_dBFS = config['params']['min_gain_dBFS'], config['params']['max_gain_dBFS']
|
203 |
+
gain = random.randint(min_gain_dBFS, max_gain_dBFS)
|
204 |
+
sample = self.volume(sample, gain=gain)
|
205 |
+
if config['type'] == 'resample' and random.random() < config['prob']:
|
206 |
+
new_sample_rates = config['params']['new_sample_rates']
|
207 |
+
new_sample_rate = np.random.choice(new_sample_rates)
|
208 |
+
sample = self.resample(sample, orig_sr=sample_rate, target_sr=new_sample_rate)
|
209 |
+
sample_rate = new_sample_rate
|
210 |
+
if config['type'] == 'noise' and random.random() < config['prob']:
|
211 |
+
min_snr_dB, max_snr_dB = config['params']['min_snr_dB'], config['params']['max_snr_dB']
|
212 |
+
if self.noises_path is None:
|
213 |
+
self.noises_path = []
|
214 |
+
noise_dir = config['params']['noise_dir']
|
215 |
+
if os.path.exists(noise_dir):
|
216 |
+
for file in os.listdir(noise_dir):
|
217 |
+
self.noises_path.append(os.path.join(noise_dir, file))
|
218 |
+
noise_path = random.choice(self.noises_path)
|
219 |
+
snr_dB = random.randint(min_snr_dB, max_snr_dB)
|
220 |
+
sample = self.add_noise(sample, sample_rate, noise_path=noise_path, snr_dB=snr_dB)
|
221 |
+
return sample, sample_rate
|
222 |
+
|
223 |
+
#
|
224 |
+
@staticmethod
|
225 |
+
def change_speed(sample, speed_rate):
|
226 |
+
if speed_rate == 1.0:
|
227 |
+
return sample
|
228 |
+
if speed_rate <= 0:
|
229 |
+
raise ValueError("speed_rate must be greater than 0")
|
230 |
+
old_length = sample.shape[0]
|
231 |
+
new_length = int(old_length / speed_rate)
|
232 |
+
old_indices = np.arange(old_length)
|
233 |
+
new_indices = np.linspace(start=0, stop=old_length, num=new_length)
|
234 |
+
sample = np.interp(new_indices, old_indices, sample).astype(np.float32)
|
235 |
+
return sample
|
236 |
+
|
237 |
+
#
|
238 |
+
@staticmethod
|
239 |
+
def shift(sample, sample_rate, shift_ms):
|
240 |
+
duration = sample.shape[0] / sample_rate
|
241 |
+
if abs(shift_ms) / 1000.0 > duration:
|
242 |
+
raise ValueError("The absolute value of shift_ms must not exceed the audio duration")
|
243 |
+
shift_samples = int(shift_ms * sample_rate / 1000)
|
244 |
+
if shift_samples > 0:
|
245 |
+
sample[:-shift_samples] = sample[shift_samples:]
|
246 |
+
sample[-shift_samples:] = 0
|
247 |
+
elif shift_samples < 0:
|
248 |
+
sample[-shift_samples:] = sample[:shift_samples]
|
249 |
+
sample[:-shift_samples] = 0
|
250 |
+
return sample
|
251 |
+
|
252 |
+
#
|
253 |
+
@staticmethod
|
254 |
+
def volume(sample, gain):
|
255 |
+
sample *= 10.**(gain / 20.)
|
256 |
+
return sample
|
257 |
+
|
258 |
+
#
|
259 |
+
@staticmethod
|
260 |
+
def resample(sample, orig_sr, target_sr):
|
261 |
+
sample = librosa.resample(sample, orig_sr=orig_sr, target_sr=target_sr)
|
262 |
+
return sample
|
263 |
+
|
264 |
+
#
|
265 |
+
def add_noise(self, sample, sample_rate, noise_path, snr_dB, max_gain_db=300.0):
|
266 |
+
noise_sample, sr = librosa.load(noise_path, sr=sample_rate)
|
267 |
+
#
|
268 |
+
target_db = -20
|
269 |
+
gain = min(max_gain_db, target_db - self.rms_db(sample))
|
270 |
+
sample *= 10. ** (gain / 20.)
|
271 |
+
#
|
272 |
+
sample_rms_db, noise_rms_db = self.rms_db(sample), self.rms_db(noise_sample)
|
273 |
+
noise_gain_db = min(sample_rms_db - noise_rms_db - snr_dB, max_gain_db)
|
274 |
+
noise_sample *= 10. ** (noise_gain_db / 20.)
|
275 |
+
#
|
276 |
+
if noise_sample.shape[0] < sample.shape[0]:
|
277 |
+
diff_duration = sample.shape[0] - noise_sample.shape[0]
|
278 |
+
noise_sample = np.pad(noise_sample, (0, diff_duration), 'wrap')
|
279 |
+
elif noise_sample.shape[0] > sample.shape[0]:
|
280 |
+
start_frame = random.randint(0, noise_sample.shape[0] - sample.shape[0])
|
281 |
+
noise_sample = noise_sample[start_frame:sample.shape[0] + start_frame]
|
282 |
+
sample += noise_sample
|
283 |
+
return sample
|
284 |
+
|
285 |
+
@staticmethod
|
286 |
+
def rms_db(sample):
|
287 |
+
mean_square = np.mean(sample ** 2)
|
288 |
+
return 10 * np.log10(mean_square)
|
289 |
+
|
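For reference, a minimal sketch of constructing the dataset above for fine-tuning; the checkpoint name and data list path are assumptions, and each line of the list is expected to be a JSON object with "audio" (path), "sentence" and "duration" fields:

from transformers import WhisperProcessor
from utils.reader import CustomDataset

processor = WhisperProcessor.from_pretrained("openai/whisper-small")
train_dataset = CustomDataset(data_list_path="dataset/train.json",
                              processor=processor,
                              min_duration=0.5,
                              max_duration=30,
                              augment_config_path="configs/augmentation.json")
print(len(train_dataset))
print(train_dataset[0].keys())  # input_features and labels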
utils/utils.py
ADDED
@@ -0,0 +1,87 @@
1 |
+
import hashlib
|
2 |
+
import os
|
3 |
+
import tarfile
|
4 |
+
import urllib.request
|
5 |
+
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
|
9 |
+
def print_arguments(args):
|
10 |
+
print("----------- Configuration Arguments -----------")
|
11 |
+
for arg, value in vars(args).items():
|
12 |
+
print("%s: %s" % (arg, value))
|
13 |
+
print("------------------------------------------------")
|
14 |
+
|
15 |
+
|
16 |
+
def strtobool(val):
|
17 |
+
val = val.lower()
|
18 |
+
if val in ('y', 'yes', 't', 'true', 'on', '1'):
|
19 |
+
return True
|
20 |
+
elif val in ('n', 'no', 'f', 'false', 'off', '0'):
|
21 |
+
return False
|
22 |
+
else:
|
23 |
+
raise ValueError("invalid truth value %r" % (val,))
|
24 |
+
|
25 |
+
|
26 |
+
def str_none(val):
|
27 |
+
if val == 'None':
|
28 |
+
return None
|
29 |
+
else:
|
30 |
+
return val
|
31 |
+
|
32 |
+
|
33 |
+
def add_arguments(argname, type, default, help, argparser, **kwargs):
|
34 |
+
type = strtobool if type == bool else type
|
35 |
+
type = str_none if type == str else type
|
36 |
+
argparser.add_argument("--" + argname,
|
37 |
+
default=default,
|
38 |
+
type=type,
|
39 |
+
help=help + ' Default: %(default)s.',
|
40 |
+
**kwargs)
|
41 |
+
|
42 |
+
|
43 |
+
def md5file(fname):
|
44 |
+
hash_md5 = hashlib.md5()
|
45 |
+
f = open(fname, "rb")
|
46 |
+
for chunk in iter(lambda: f.read(4096), b""):
|
47 |
+
hash_md5.update(chunk)
|
48 |
+
f.close()
|
49 |
+
return hash_md5.hexdigest()
|
50 |
+
|
51 |
+
|
52 |
+
def download(url, md5sum, target_dir):
|
53 |
+
"""Download file from url to target_dir, and check md5sum."""
|
54 |
+
if not os.path.exists(target_dir): os.makedirs(target_dir)
|
55 |
+
filepath = os.path.join(target_dir, url.split("/")[-1])
|
56 |
+
if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
|
57 |
+
print(f"Downloading {url} to {filepath} ...")
|
58 |
+
with urllib.request.urlopen(url) as source, open(filepath, "wb") as output:
|
59 |
+
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True,
|
60 |
+
unit_divisor=1024) as loop:
|
61 |
+
while True:
|
62 |
+
buffer = source.read(8192)
|
63 |
+
if not buffer:
|
64 |
+
break
|
65 |
+
|
66 |
+
output.write(buffer)
|
67 |
+
loop.update(len(buffer))
|
68 |
+
print(f"\nMD5 Checksum {filepath} ...")
|
69 |
+
if not md5file(filepath) == md5sum:
|
70 |
+
raise RuntimeError("MD5 checksum failed.")
|
71 |
+
else:
|
72 |
+
print(f"File exists, skip downloading. ({filepath})")
|
73 |
+
return filepath
|
74 |
+
|
75 |
+
|
76 |
+
def unpack(filepath, target_dir, rm_tar=False):
|
77 |
+
"""Unpack the file to the target_dir."""
|
78 |
+
print("Unpacking %s ..." % filepath)
|
79 |
+
tar = tarfile.open(filepath)
|
80 |
+
tar.extractall(target_dir)
|
81 |
+
tar.close()
|
82 |
+
if rm_tar:
|
83 |
+
os.remove(filepath)
|
84 |
+
|
85 |
+
|
86 |
+
def make_inputs_require_grad(module, input, output):
|
87 |
+
output.requires_grad_(True)
|
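Finally, a short sketch of how the argument helpers above can be wired into an argparse-based script; the flag names and defaults below are placeholders, not values taken from this diff:

import argparse
import functools

from utils.utils import add_arguments, print_arguments

parser = argparse.ArgumentParser()
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg("model_path", type=str, default="models/whisper-small", help="Path to the model.")
add_arg("use_gpu", type=bool, default=True, help="Whether to run on the GPU.")
args = parser.parse_args()
print_arguments(args)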