HoneyTian
commited on
Commit
•
69ad385
0
Parent(s):
update
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +35 -0
- .gitignore +18 -0
- Dockerfile +21 -0
- README.md +11 -0
- examples/vm_sound_classification/conv2d_classifier.yaml +45 -0
- examples/vm_sound_classification/requirements.txt +10 -0
- examples/vm_sound_classification/run.sh +188 -0
- examples/vm_sound_classification/step_1_prepare_data.py +150 -0
- examples/vm_sound_classification/step_2_make_vocabulary.py +51 -0
- examples/vm_sound_classification/step_3_train_model.py +331 -0
- examples/vm_sound_classification/step_4_evaluation_model.py +128 -0
- examples/vm_sound_classification/step_5_export_models.py +106 -0
- examples/vm_sound_classification/step_6_infer.py +91 -0
- examples/vm_sound_classification/step_7_test_model.py +93 -0
- examples/vm_sound_classification/stop.sh +3 -0
- examples/vm_sound_classification8/requirements.txt +9 -0
- examples/vm_sound_classification8/run.sh +157 -0
- examples/vm_sound_classification8/step_1_prepare_data.py +156 -0
- examples/vm_sound_classification8/step_2_make_vocabulary.py +69 -0
- examples/vm_sound_classification8/step_3_train_global_model.py +328 -0
- examples/vm_sound_classification8/step_4_train_country_model.py +349 -0
- examples/vm_sound_classification8/step_5_train_union.py +499 -0
- examples/vm_sound_classification8/stop.sh +3 -0
- install.sh +64 -0
- main.py +172 -0
- project_settings.py +19 -0
- requirements.txt +12 -0
- script/install_nvidia_driver.sh +184 -0
- script/install_python.sh +129 -0
- toolbox/__init__.py +5 -0
- toolbox/json/__init__.py +6 -0
- toolbox/json/misc.py +63 -0
- toolbox/os/__init__.py +6 -0
- toolbox/os/command.py +59 -0
- toolbox/os/environment.py +114 -0
- toolbox/os/other.py +9 -0
- toolbox/torch/__init__.py +5 -0
- toolbox/torch/modules/__init__.py +6 -0
- toolbox/torch/modules/gaussian_mixture.py +173 -0
- toolbox/torch/modules/highway.py +30 -0
- toolbox/torch/modules/loss.py +738 -0
- toolbox/torch/training/__init__.py +6 -0
- toolbox/torch/training/metrics/__init__.py +6 -0
- toolbox/torch/training/metrics/categorical_accuracy.py +82 -0
- toolbox/torch/training/metrics/verbose_categorical_accuracy.py +128 -0
- toolbox/torch/training/trainer/__init__.py +5 -0
- toolbox/torch/training/trainer/trainer.py +5 -0
- toolbox/torch/utils/__init__.py +5 -0
- toolbox/torchaudio/__init__.py +5 -0
- toolbox/torchaudio/configuration_utils.py +63 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
.git/
|
3 |
+
.idea/
|
4 |
+
|
5 |
+
**/file_dir
|
6 |
+
**/flagged/
|
7 |
+
**/log/
|
8 |
+
**/logs/
|
9 |
+
**/__pycache__/
|
10 |
+
|
11 |
+
data/
|
12 |
+
docs/
|
13 |
+
dotenv/
|
14 |
+
trained_models/
|
15 |
+
temp/
|
16 |
+
|
17 |
+
#**/*.wav
|
18 |
+
**/*.xlsx
|
Dockerfile
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.8
|
2 |
+
|
3 |
+
WORKDIR /code
|
4 |
+
|
5 |
+
COPY . /code
|
6 |
+
|
7 |
+
RUN pip install --upgrade pip
|
8 |
+
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
9 |
+
|
10 |
+
RUN useradd -m -u 1000 user
|
11 |
+
|
12 |
+
USER user
|
13 |
+
|
14 |
+
ENV HOME=/home/user \
|
15 |
+
PATH=/home/user/.local/bin:$PATH
|
16 |
+
|
17 |
+
WORKDIR $HOME/app
|
18 |
+
|
19 |
+
COPY --chown=user . $HOME/app
|
20 |
+
|
21 |
+
CMD ["python3", "main.py"]
|
README.md
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: VM Sound Classification
|
3 |
+
emoji: 🐢
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: blue
|
6 |
+
sdk: docker
|
7 |
+
pinned: false
|
8 |
+
license: apache-2.0
|
9 |
+
---
|
10 |
+
|
11 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
examples/vm_sound_classification/conv2d_classifier.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_name: "cnn_audio_classifier"
|
2 |
+
|
3 |
+
mel_spectrogram_param:
|
4 |
+
sample_rate: 8000
|
5 |
+
n_fft: 512
|
6 |
+
win_length: 200
|
7 |
+
hop_length: 80
|
8 |
+
f_min: 10
|
9 |
+
f_max: 3800
|
10 |
+
window_fn: hamming
|
11 |
+
n_mels: 80
|
12 |
+
|
13 |
+
conv2d_block_param_list:
|
14 |
+
- batch_norm: true
|
15 |
+
in_channels: 1
|
16 |
+
out_channels: 16
|
17 |
+
kernel_size: 3
|
18 |
+
stride: 1
|
19 |
+
dilation: 3
|
20 |
+
activation: relu
|
21 |
+
dropout: 0.1
|
22 |
+
- in_channels: 16
|
23 |
+
out_channels: 16
|
24 |
+
kernel_size: 5
|
25 |
+
stride: 2
|
26 |
+
dilation: 3
|
27 |
+
activation: relu
|
28 |
+
dropout: 0.1
|
29 |
+
- in_channels: 16
|
30 |
+
out_channels: 16
|
31 |
+
kernel_size: 3
|
32 |
+
stride: 1
|
33 |
+
dilation: 2
|
34 |
+
activation: relu
|
35 |
+
dropout: 0.1
|
36 |
+
|
37 |
+
cls_head_param:
|
38 |
+
input_dim: 432
|
39 |
+
num_layers: 2
|
40 |
+
hidden_dims:
|
41 |
+
- 128
|
42 |
+
- 32
|
43 |
+
activations: relu
|
44 |
+
dropout: 0.1
|
45 |
+
num_labels: 8
|
examples/vm_sound_classification/requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==1.13.1
|
2 |
+
torchaudio==0.13.1
|
3 |
+
fsspec==2022.1.0
|
4 |
+
librosa==0.9.2
|
5 |
+
pandas==1.1.5
|
6 |
+
openpyxl==3.0.9
|
7 |
+
xlrd==1.2.0
|
8 |
+
tqdm==4.64.1
|
9 |
+
overrides==1.9.0
|
10 |
+
pyyaml==6.0.1
|
examples/vm_sound_classification/run.sh
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
: <<'END'
|
4 |
+
|
5 |
+
sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification3 \
|
6 |
+
--filename_patterns "E:/programmer/asr_datasets/voicemail/wav_finished/en-US/wav_finished/*/*.wav \
|
7 |
+
E:/programmer/asr_datasets/voicemail/wav_finished/id-ID/wav_finished/*/*.wav" \
|
8 |
+
|
9 |
+
sh run.sh --stage 0 --stop_stage 1 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification3 \
|
10 |
+
--filename_patterns "E:/programmer/asr_datasets/voicemail/wav_finished/en-US/wav_finished/*/*.wav \
|
11 |
+
E:/programmer/asr_datasets/voicemail/wav_finished/id-ID/wav_finished/*/*.wav" \
|
12 |
+
|
13 |
+
sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification3
|
14 |
+
sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification3
|
15 |
+
|
16 |
+
sh run.sh --stage 3 --stop_stage 5 --system_version centos --file_folder_name file_dir --final_model_name vm_sound_classification8-ch16 \
|
17 |
+
--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
|
18 |
+
|
19 |
+
|
20 |
+
"
|
21 |
+
|
22 |
+
END
|
23 |
+
|
24 |
+
|
25 |
+
# params
|
26 |
+
system_version="windows";
|
27 |
+
verbose=true;
|
28 |
+
stage=0 # start from 0 if you need to start from data preparation
|
29 |
+
stop_stage=9
|
30 |
+
|
31 |
+
work_dir="$(pwd)"
|
32 |
+
file_folder_name=file_folder_name
|
33 |
+
final_model_name=final_model_name
|
34 |
+
filename_patterns="/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
|
35 |
+
nohup_name=nohup.out
|
36 |
+
|
37 |
+
country=en-US
|
38 |
+
|
39 |
+
# model params
|
40 |
+
batch_size=64
|
41 |
+
max_epochs=200
|
42 |
+
save_top_k=10
|
43 |
+
patience=5
|
44 |
+
|
45 |
+
|
46 |
+
# parse options
|
47 |
+
while true; do
|
48 |
+
[ -z "${1:-}" ] && break; # break if there are no arguments
|
49 |
+
case "$1" in
|
50 |
+
--*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
|
51 |
+
eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
|
52 |
+
old_value="(eval echo \\$$name)";
|
53 |
+
if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
|
54 |
+
was_bool=true;
|
55 |
+
else
|
56 |
+
was_bool=false;
|
57 |
+
fi
|
58 |
+
|
59 |
+
# Set the variable to the right value-- the escaped quotes make it work if
|
60 |
+
# the option had spaces, like --cmd "queue.pl -sync y"
|
61 |
+
eval "${name}=\"$2\"";
|
62 |
+
|
63 |
+
# Check that Boolean-valued arguments are really Boolean.
|
64 |
+
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
|
65 |
+
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
|
66 |
+
exit 1;
|
67 |
+
fi
|
68 |
+
shift 2;
|
69 |
+
;;
|
70 |
+
|
71 |
+
*) break;
|
72 |
+
esac
|
73 |
+
done
|
74 |
+
|
75 |
+
file_dir="${work_dir}/${file_folder_name}"
|
76 |
+
final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
|
77 |
+
|
78 |
+
dataset="${file_dir}/dataset.xlsx"
|
79 |
+
train_dataset="${file_dir}/train.xlsx"
|
80 |
+
valid_dataset="${file_dir}/valid.xlsx"
|
81 |
+
evaluation_file="${file_dir}/evaluation.xlsx"
|
82 |
+
vocabulary_dir="${file_dir}/vocabulary"
|
83 |
+
|
84 |
+
$verbose && echo "system_version: ${system_version}"
|
85 |
+
$verbose && echo "file_folder_name: ${file_folder_name}"
|
86 |
+
|
87 |
+
if [ $system_version == "windows" ]; then
|
88 |
+
alias python3='D:/Users/tianx/PycharmProjects/virtualenv/vm_sound_classification/Scripts/python.exe'
|
89 |
+
elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
|
90 |
+
#source /data/local/bin/vm_sound_classification/bin/activate
|
91 |
+
alias python3='/data/local/bin/vm_sound_classification/bin/python3'
|
92 |
+
fi
|
93 |
+
|
94 |
+
|
95 |
+
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
96 |
+
$verbose && echo "stage 0: prepare data"
|
97 |
+
cd "${work_dir}" || exit 1
|
98 |
+
python3 step_1_prepare_data.py \
|
99 |
+
--file_dir "${file_dir}" \
|
100 |
+
--filename_patterns "${filename_patterns}" \
|
101 |
+
--train_dataset "${train_dataset}" \
|
102 |
+
--valid_dataset "${valid_dataset}" \
|
103 |
+
|
104 |
+
fi
|
105 |
+
|
106 |
+
|
107 |
+
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
108 |
+
$verbose && echo "stage 1: make vocabulary"
|
109 |
+
cd "${work_dir}" || exit 1
|
110 |
+
python3 step_2_make_vocabulary.py \
|
111 |
+
--vocabulary_dir "${vocabulary_dir}" \
|
112 |
+
--train_dataset "${train_dataset}" \
|
113 |
+
--valid_dataset "${valid_dataset}" \
|
114 |
+
|
115 |
+
fi
|
116 |
+
|
117 |
+
|
118 |
+
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
119 |
+
$verbose && echo "stage 2: train model"
|
120 |
+
cd "${work_dir}" || exit 1
|
121 |
+
python3 step_3_train_model.py \
|
122 |
+
--vocabulary_dir "${vocabulary_dir}" \
|
123 |
+
--train_dataset "${train_dataset}" \
|
124 |
+
--valid_dataset "${valid_dataset}" \
|
125 |
+
--serialization_dir "${file_dir}" \
|
126 |
+
|
127 |
+
fi
|
128 |
+
|
129 |
+
|
130 |
+
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
131 |
+
$verbose && echo "stage 3: test model"
|
132 |
+
cd "${work_dir}" || exit 1
|
133 |
+
python3 step_4_evaluation_model.py \
|
134 |
+
--dataset "${dataset}" \
|
135 |
+
--vocabulary_dir "${vocabulary_dir}" \
|
136 |
+
--model_dir "${file_dir}/best" \
|
137 |
+
--output_file "${evaluation_file}" \
|
138 |
+
|
139 |
+
fi
|
140 |
+
|
141 |
+
|
142 |
+
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
143 |
+
$verbose && echo "stage 4: export model"
|
144 |
+
cd "${work_dir}" || exit 1
|
145 |
+
python3 step_5_export_models.py \
|
146 |
+
--vocabulary_dir "${vocabulary_dir}" \
|
147 |
+
--model_dir "${file_dir}/best" \
|
148 |
+
--serialization_dir "${file_dir}" \
|
149 |
+
|
150 |
+
fi
|
151 |
+
|
152 |
+
|
153 |
+
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
154 |
+
$verbose && echo "stage 5: collect files"
|
155 |
+
cd "${work_dir}" || exit 1
|
156 |
+
|
157 |
+
mkdir -p ${final_model_dir}
|
158 |
+
|
159 |
+
cp "${file_dir}/best"/* "${final_model_dir}"
|
160 |
+
cp -r "${file_dir}/vocabulary" "${final_model_dir}"
|
161 |
+
|
162 |
+
cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx"
|
163 |
+
|
164 |
+
cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip"
|
165 |
+
cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip"
|
166 |
+
cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip"
|
167 |
+
cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip"
|
168 |
+
|
169 |
+
cd "${final_model_dir}/.." || exit 1;
|
170 |
+
|
171 |
+
if [ -e "${final_model_name}.zip" ]; then
|
172 |
+
rm -rf "${final_model_name}_backup.zip"
|
173 |
+
mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
|
174 |
+
fi
|
175 |
+
|
176 |
+
zip -r "${final_model_name}.zip" "${final_model_name}"
|
177 |
+
rm -rf "${final_model_name}"
|
178 |
+
|
179 |
+
fi
|
180 |
+
|
181 |
+
|
182 |
+
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
|
183 |
+
$verbose && echo "stage 6: clear file_dir"
|
184 |
+
cd "${work_dir}" || exit 1
|
185 |
+
|
186 |
+
rm -rf "${file_dir}";
|
187 |
+
|
188 |
+
fi
|
examples/vm_sound_classification/step_1_prepare_data.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
from glob import glob
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
from pathlib import Path
|
8 |
+
import random
|
9 |
+
import sys
|
10 |
+
|
11 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
12 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
13 |
+
|
14 |
+
import pandas as pd
|
15 |
+
from scipy.io import wavfile
|
16 |
+
from tqdm import tqdm
|
17 |
+
|
18 |
+
|
19 |
+
def get_args():
|
20 |
+
parser = argparse.ArgumentParser()
|
21 |
+
parser.add_argument("--file_dir", default="./", type=str)
|
22 |
+
parser.add_argument("--task", default="default", type=str)
|
23 |
+
parser.add_argument("--filename_patterns", type=str)
|
24 |
+
|
25 |
+
parser.add_argument("--train_dataset", default="train.xlsx", type=str)
|
26 |
+
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
|
27 |
+
|
28 |
+
args = parser.parse_args()
|
29 |
+
return args
|
30 |
+
|
31 |
+
|
32 |
+
def get_dataset(args):
|
33 |
+
filename_patterns = args.filename_patterns
|
34 |
+
filename_patterns = filename_patterns.split(" ")
|
35 |
+
print(filename_patterns)
|
36 |
+
|
37 |
+
file_dir = Path(args.file_dir)
|
38 |
+
file_dir.mkdir(exist_ok=True)
|
39 |
+
|
40 |
+
# label3_map = {
|
41 |
+
# "bell": "voicemail",
|
42 |
+
# "white_noise": "mute",
|
43 |
+
# "low_white_noise": "mute",
|
44 |
+
# "high_white_noise": "mute",
|
45 |
+
# # "music": "music",
|
46 |
+
# "mute": "mute",
|
47 |
+
# "noise": "voice_or_noise",
|
48 |
+
# "noise_mute": "voice_or_noise",
|
49 |
+
# "voice": "voice_or_noise",
|
50 |
+
# "voicemail": "voicemail",
|
51 |
+
# }
|
52 |
+
label8_map = {
|
53 |
+
"bell": "bell",
|
54 |
+
"white_noise": "white_noise",
|
55 |
+
"low_white_noise": "white_noise",
|
56 |
+
"high_white_noise": "white_noise",
|
57 |
+
"music": "music",
|
58 |
+
"mute": "mute",
|
59 |
+
"noise": "noise",
|
60 |
+
"noise_mute": "noise_mute",
|
61 |
+
"voice": "voice",
|
62 |
+
"voicemail": "voicemail",
|
63 |
+
}
|
64 |
+
|
65 |
+
result = list()
|
66 |
+
for filename_pattern in filename_patterns:
|
67 |
+
filename_list = glob(filename_pattern)
|
68 |
+
for filename in tqdm(filename_list):
|
69 |
+
filename = Path(filename)
|
70 |
+
sample_rate, signal = wavfile.read(filename.as_posix())
|
71 |
+
if len(signal) < sample_rate * 2:
|
72 |
+
continue
|
73 |
+
|
74 |
+
folder = filename.parts[-2]
|
75 |
+
country = filename.parts[-4]
|
76 |
+
|
77 |
+
if folder not in label8_map.keys():
|
78 |
+
continue
|
79 |
+
|
80 |
+
labels = label8_map[folder]
|
81 |
+
|
82 |
+
random1 = random.random()
|
83 |
+
random2 = random.random()
|
84 |
+
|
85 |
+
result.append({
|
86 |
+
"filename": filename,
|
87 |
+
"folder": folder,
|
88 |
+
"category": country,
|
89 |
+
"labels": labels,
|
90 |
+
"random1": random1,
|
91 |
+
"random2": random2,
|
92 |
+
"flag": "TRAIN" if random2 < 0.8 else "TEST",
|
93 |
+
})
|
94 |
+
|
95 |
+
df = pd.DataFrame(result)
|
96 |
+
pivot_table = pd.pivot_table(df, index=["labels"], values=["filename"], aggfunc="count")
|
97 |
+
print(pivot_table)
|
98 |
+
|
99 |
+
df = df.sort_values(by=["random1"], ascending=False)
|
100 |
+
df.to_excel(
|
101 |
+
file_dir / "dataset.xlsx",
|
102 |
+
index=False,
|
103 |
+
# encoding="utf_8_sig"
|
104 |
+
)
|
105 |
+
|
106 |
+
return
|
107 |
+
|
108 |
+
|
109 |
+
def split_dataset(args):
|
110 |
+
"""分割训练集, 测试集"""
|
111 |
+
file_dir = Path(args.file_dir)
|
112 |
+
file_dir.mkdir(exist_ok=True)
|
113 |
+
|
114 |
+
df = pd.read_excel(file_dir / "dataset.xlsx")
|
115 |
+
|
116 |
+
train = list()
|
117 |
+
test = list()
|
118 |
+
|
119 |
+
for i, row in df.iterrows():
|
120 |
+
flag = row["flag"]
|
121 |
+
if flag == "TRAIN":
|
122 |
+
train.append(row)
|
123 |
+
else:
|
124 |
+
test.append(row)
|
125 |
+
|
126 |
+
train = pd.DataFrame(train)
|
127 |
+
train.to_excel(
|
128 |
+
args.train_dataset,
|
129 |
+
index=False,
|
130 |
+
# encoding="utf_8_sig"
|
131 |
+
)
|
132 |
+
test = pd.DataFrame(test)
|
133 |
+
test.to_excel(
|
134 |
+
args.valid_dataset,
|
135 |
+
index=False,
|
136 |
+
# encoding="utf_8_sig"
|
137 |
+
)
|
138 |
+
|
139 |
+
return
|
140 |
+
|
141 |
+
|
142 |
+
def main():
|
143 |
+
args = get_args()
|
144 |
+
get_dataset(args)
|
145 |
+
split_dataset(args)
|
146 |
+
return
|
147 |
+
|
148 |
+
|
149 |
+
if __name__ == "__main__":
|
150 |
+
main()
|
examples/vm_sound_classification/step_2_make_vocabulary.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
import sys
|
7 |
+
|
8 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
9 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
10 |
+
|
11 |
+
import pandas as pd
|
12 |
+
|
13 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
14 |
+
|
15 |
+
|
16 |
+
def get_args():
|
17 |
+
parser = argparse.ArgumentParser()
|
18 |
+
parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
|
19 |
+
|
20 |
+
parser.add_argument("--train_dataset", default="train.xlsx", type=str)
|
21 |
+
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
|
22 |
+
|
23 |
+
args = parser.parse_args()
|
24 |
+
return args
|
25 |
+
|
26 |
+
|
27 |
+
def main():
|
28 |
+
args = get_args()
|
29 |
+
|
30 |
+
train_dataset = pd.read_excel(args.train_dataset)
|
31 |
+
valid_dataset = pd.read_excel(args.valid_dataset)
|
32 |
+
|
33 |
+
vocabulary = Vocabulary()
|
34 |
+
|
35 |
+
# train
|
36 |
+
for i, row in train_dataset.iterrows():
|
37 |
+
label = row["labels"]
|
38 |
+
vocabulary.add_token_to_namespace(label, namespace="labels")
|
39 |
+
|
40 |
+
# valid
|
41 |
+
for i, row in valid_dataset.iterrows():
|
42 |
+
label = row["labels"]
|
43 |
+
vocabulary.add_token_to_namespace(label, namespace="labels")
|
44 |
+
|
45 |
+
vocabulary.save_to_files(args.vocabulary_dir)
|
46 |
+
|
47 |
+
return
|
48 |
+
|
49 |
+
|
50 |
+
if __name__ == "__main__":
|
51 |
+
main()
|
examples/vm_sound_classification/step_3_train_model.py
ADDED
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
from collections import defaultdict
|
5 |
+
import json
|
6 |
+
import logging
|
7 |
+
from logging.handlers import TimedRotatingFileHandler
|
8 |
+
import os
|
9 |
+
import platform
|
10 |
+
from pathlib import Path
|
11 |
+
import random
|
12 |
+
import sys
|
13 |
+
import shutil
|
14 |
+
from typing import List
|
15 |
+
|
16 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
17 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
from torch.utils.data.dataloader import DataLoader
|
22 |
+
from tqdm import tqdm
|
23 |
+
|
24 |
+
from toolbox.torch.modules.loss import FocalLoss, HingeLoss, HingeLinear
|
25 |
+
from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
|
26 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
27 |
+
from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
|
28 |
+
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel
|
29 |
+
from toolbox.torchaudio.models.cnn_audio_classifier.configuration_cnn_audio_classifier import CnnAudioClassifierConfig
|
30 |
+
|
31 |
+
|
32 |
+
def get_args():
|
33 |
+
parser = argparse.ArgumentParser()
|
34 |
+
parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
|
35 |
+
|
36 |
+
parser.add_argument("--train_dataset", default="train.xlsx", type=str)
|
37 |
+
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
|
38 |
+
|
39 |
+
parser.add_argument("--max_epochs", default=100, type=int)
|
40 |
+
|
41 |
+
parser.add_argument("--batch_size", default=64, type=int)
|
42 |
+
parser.add_argument("--learning_rate", default=1e-3, type=float)
|
43 |
+
parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
|
44 |
+
parser.add_argument("--patience", default=5, type=int)
|
45 |
+
parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
|
46 |
+
parser.add_argument("--seed", default=0, type=int)
|
47 |
+
|
48 |
+
parser.add_argument("--config_file", default="conv2d_classifier.yaml", type=str)
|
49 |
+
|
50 |
+
args = parser.parse_args()
|
51 |
+
return args
|
52 |
+
|
53 |
+
|
54 |
+
def logging_config(file_dir: str):
|
55 |
+
fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
|
56 |
+
|
57 |
+
logging.basicConfig(format=fmt,
|
58 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
59 |
+
level=logging.DEBUG)
|
60 |
+
file_handler = TimedRotatingFileHandler(
|
61 |
+
filename=os.path.join(file_dir, "main.log"),
|
62 |
+
encoding="utf-8",
|
63 |
+
when="D",
|
64 |
+
interval=1,
|
65 |
+
backupCount=7
|
66 |
+
)
|
67 |
+
file_handler.setLevel(logging.INFO)
|
68 |
+
file_handler.setFormatter(logging.Formatter(fmt))
|
69 |
+
logger = logging.getLogger(__name__)
|
70 |
+
logger.addHandler(file_handler)
|
71 |
+
|
72 |
+
return logger
|
73 |
+
|
74 |
+
|
75 |
+
class CollateFunction(object):
|
76 |
+
def __init__(self):
|
77 |
+
pass
|
78 |
+
|
79 |
+
def __call__(self, batch: List[dict]):
|
80 |
+
array_list = list()
|
81 |
+
label_list = list()
|
82 |
+
for sample in batch:
|
83 |
+
array = sample["waveform"]
|
84 |
+
label = sample["label"]
|
85 |
+
|
86 |
+
array_list.append(array)
|
87 |
+
label_list.append(label)
|
88 |
+
|
89 |
+
array_list = torch.stack(array_list)
|
90 |
+
label_list = torch.stack(label_list)
|
91 |
+
return array_list, label_list
|
92 |
+
|
93 |
+
|
94 |
+
collate_fn = CollateFunction()
|
95 |
+
|
96 |
+
|
97 |
+
def main():
|
98 |
+
args = get_args()
|
99 |
+
|
100 |
+
serialization_dir = Path(args.serialization_dir)
|
101 |
+
serialization_dir.mkdir(parents=True, exist_ok=True)
|
102 |
+
|
103 |
+
logger = logging_config(serialization_dir)
|
104 |
+
|
105 |
+
random.seed(args.seed)
|
106 |
+
np.random.seed(args.seed)
|
107 |
+
torch.manual_seed(args.seed)
|
108 |
+
logger.info("set seed: {}".format(args.seed))
|
109 |
+
|
110 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
111 |
+
n_gpu = torch.cuda.device_count()
|
112 |
+
logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
|
113 |
+
|
114 |
+
vocabulary = Vocabulary.from_files(args.vocabulary_dir)
|
115 |
+
|
116 |
+
# datasets
|
117 |
+
logger.info("prepare datasets")
|
118 |
+
train_dataset = WaveClassifierExcelDataset(
|
119 |
+
vocab=vocabulary,
|
120 |
+
excel_file=args.train_dataset,
|
121 |
+
category=None,
|
122 |
+
category_field="category",
|
123 |
+
label_field="labels",
|
124 |
+
expected_sample_rate=8000,
|
125 |
+
max_wave_value=32768.0,
|
126 |
+
)
|
127 |
+
valid_dataset = WaveClassifierExcelDataset(
|
128 |
+
vocab=vocabulary,
|
129 |
+
excel_file=args.valid_dataset,
|
130 |
+
category=None,
|
131 |
+
category_field="category",
|
132 |
+
label_field="labels",
|
133 |
+
expected_sample_rate=8000,
|
134 |
+
max_wave_value=32768.0,
|
135 |
+
)
|
136 |
+
train_data_loader = DataLoader(
|
137 |
+
dataset=train_dataset,
|
138 |
+
batch_size=args.batch_size,
|
139 |
+
shuffle=True,
|
140 |
+
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
141 |
+
num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
|
142 |
+
collate_fn=collate_fn,
|
143 |
+
pin_memory=False,
|
144 |
+
# prefetch_factor=64,
|
145 |
+
)
|
146 |
+
valid_data_loader = DataLoader(
|
147 |
+
dataset=valid_dataset,
|
148 |
+
batch_size=args.batch_size,
|
149 |
+
shuffle=True,
|
150 |
+
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
151 |
+
num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
|
152 |
+
collate_fn=collate_fn,
|
153 |
+
pin_memory=False,
|
154 |
+
# prefetch_factor=64,
|
155 |
+
)
|
156 |
+
|
157 |
+
# models
|
158 |
+
logger.info("prepare models")
|
159 |
+
config = CnnAudioClassifierConfig.from_pretrained(
|
160 |
+
pretrained_model_name_or_path=args.config_file,
|
161 |
+
# num_labels=vocabulary.get_vocab_size(namespace="labels")
|
162 |
+
)
|
163 |
+
if not config.cls_head_param["num_labels"] == vocabulary.get_vocab_size(namespace="labels"):
|
164 |
+
raise AssertionError
|
165 |
+
model = WaveClassifierPretrainedModel(
|
166 |
+
config=config,
|
167 |
+
)
|
168 |
+
model.to(device)
|
169 |
+
model.train()
|
170 |
+
|
171 |
+
# optimizer
|
172 |
+
logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
|
173 |
+
param_optimizer = model.parameters()
|
174 |
+
optimizer = torch.optim.Adam(
|
175 |
+
param_optimizer,
|
176 |
+
lr=args.learning_rate,
|
177 |
+
)
|
178 |
+
# lr_scheduler = torch.optim.lr_scheduler.StepLR(
|
179 |
+
# optimizer,
|
180 |
+
# step_size=2000
|
181 |
+
# )
|
182 |
+
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
183 |
+
optimizer,
|
184 |
+
milestones=[10000, 20000, 30000], gamma=0.5
|
185 |
+
)
|
186 |
+
focal_loss = FocalLoss(
|
187 |
+
num_classes=vocabulary.get_vocab_size(namespace="labels"),
|
188 |
+
reduction="mean",
|
189 |
+
)
|
190 |
+
categorical_accuracy = CategoricalAccuracy()
|
191 |
+
|
192 |
+
# training loop
|
193 |
+
logger.info("training")
|
194 |
+
|
195 |
+
training_loss = 10000000000
|
196 |
+
training_accuracy = 0.
|
197 |
+
evaluation_loss = 10000000000
|
198 |
+
evaluation_accuracy = 0.
|
199 |
+
|
200 |
+
model_list = list()
|
201 |
+
best_idx_epoch = None
|
202 |
+
best_accuracy = None
|
203 |
+
patience_count = 0
|
204 |
+
|
205 |
+
for idx_epoch in range(args.max_epochs):
|
206 |
+
categorical_accuracy.reset()
|
207 |
+
total_loss = 0.
|
208 |
+
total_examples = 0.
|
209 |
+
progress_bar = tqdm(
|
210 |
+
total=len(train_data_loader),
|
211 |
+
desc="Training; epoch: {}".format(idx_epoch),
|
212 |
+
)
|
213 |
+
for batch in train_data_loader:
|
214 |
+
input_ids, label_ids = batch
|
215 |
+
input_ids = input_ids.to(device)
|
216 |
+
label_ids: torch.LongTensor = label_ids.to(device).long()
|
217 |
+
|
218 |
+
logits = model.forward(input_ids)
|
219 |
+
loss = focal_loss.forward(logits, label_ids.view(-1))
|
220 |
+
categorical_accuracy(logits, label_ids)
|
221 |
+
|
222 |
+
total_loss += loss.item()
|
223 |
+
total_examples += input_ids.size(0)
|
224 |
+
|
225 |
+
optimizer.zero_grad()
|
226 |
+
loss.backward()
|
227 |
+
optimizer.step()
|
228 |
+
lr_scheduler.step()
|
229 |
+
|
230 |
+
training_loss = total_loss / total_examples
|
231 |
+
training_loss = round(training_loss, 4)
|
232 |
+
training_accuracy = categorical_accuracy.get_metric()["accuracy"]
|
233 |
+
training_accuracy = round(training_accuracy, 4)
|
234 |
+
|
235 |
+
progress_bar.update(1)
|
236 |
+
progress_bar.set_postfix({
|
237 |
+
"training_loss": training_loss,
|
238 |
+
"training_accuracy": training_accuracy,
|
239 |
+
})
|
240 |
+
|
241 |
+
categorical_accuracy.reset()
|
242 |
+
total_loss = 0.
|
243 |
+
total_examples = 0.
|
244 |
+
progress_bar = tqdm(
|
245 |
+
total=len(valid_data_loader),
|
246 |
+
desc="Evaluation; epoch: {}".format(idx_epoch),
|
247 |
+
)
|
248 |
+
for batch in valid_data_loader:
|
249 |
+
input_ids, label_ids = batch
|
250 |
+
input_ids = input_ids.to(device)
|
251 |
+
label_ids: torch.LongTensor = label_ids.to(device).long()
|
252 |
+
|
253 |
+
with torch.no_grad():
|
254 |
+
logits = model.forward(input_ids)
|
255 |
+
loss = focal_loss.forward(logits, label_ids.view(-1))
|
256 |
+
categorical_accuracy(logits, label_ids)
|
257 |
+
|
258 |
+
total_loss += loss.item()
|
259 |
+
total_examples += input_ids.size(0)
|
260 |
+
|
261 |
+
evaluation_loss = total_loss / total_examples
|
262 |
+
evaluation_loss = round(evaluation_loss, 4)
|
263 |
+
evaluation_accuracy = categorical_accuracy.get_metric()["accuracy"]
|
264 |
+
evaluation_accuracy = round(evaluation_accuracy, 4)
|
265 |
+
|
266 |
+
progress_bar.update(1)
|
267 |
+
progress_bar.set_postfix({
|
268 |
+
"evaluation_loss": evaluation_loss,
|
269 |
+
"evaluation_accuracy": evaluation_accuracy,
|
270 |
+
})
|
271 |
+
|
272 |
+
# save path
|
273 |
+
epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
|
274 |
+
epoch_dir.mkdir(parents=True, exist_ok=False)
|
275 |
+
|
276 |
+
# save models
|
277 |
+
model.save_pretrained(epoch_dir.as_posix())
|
278 |
+
|
279 |
+
model_list.append(epoch_dir)
|
280 |
+
if len(model_list) >= args.num_serialized_models_to_keep:
|
281 |
+
model_to_delete: Path = model_list.pop(0)
|
282 |
+
shutil.rmtree(model_to_delete.as_posix())
|
283 |
+
|
284 |
+
# save metric
|
285 |
+
if best_accuracy is None:
|
286 |
+
best_idx_epoch = idx_epoch
|
287 |
+
best_accuracy = evaluation_accuracy
|
288 |
+
elif evaluation_accuracy > best_accuracy:
|
289 |
+
best_idx_epoch = idx_epoch
|
290 |
+
best_accuracy = evaluation_accuracy
|
291 |
+
else:
|
292 |
+
pass
|
293 |
+
|
294 |
+
metrics = {
|
295 |
+
"idx_epoch": idx_epoch,
|
296 |
+
"best_idx_epoch": best_idx_epoch,
|
297 |
+
"best_accuracy": best_accuracy,
|
298 |
+
"training_loss": training_loss,
|
299 |
+
"training_accuracy": training_accuracy,
|
300 |
+
"evaluation_loss": evaluation_loss,
|
301 |
+
"evaluation_accuracy": evaluation_accuracy,
|
302 |
+
"learning_rate": optimizer.param_groups[0]['lr'],
|
303 |
+
}
|
304 |
+
metrics_filename = epoch_dir / "metrics_epoch.json"
|
305 |
+
with open(metrics_filename, "w", encoding="utf-8") as f:
|
306 |
+
json.dump(metrics, f, indent=4, ensure_ascii=False)
|
307 |
+
|
308 |
+
# save best
|
309 |
+
best_dir = serialization_dir / "best"
|
310 |
+
if best_idx_epoch == idx_epoch:
|
311 |
+
if best_dir.exists():
|
312 |
+
shutil.rmtree(best_dir)
|
313 |
+
shutil.copytree(epoch_dir, best_dir)
|
314 |
+
|
315 |
+
# early stop
|
316 |
+
early_stop_flag = False
|
317 |
+
if best_idx_epoch == idx_epoch:
|
318 |
+
patience_count = 0
|
319 |
+
else:
|
320 |
+
patience_count += 1
|
321 |
+
if patience_count >= args.patience:
|
322 |
+
early_stop_flag = True
|
323 |
+
|
324 |
+
# early stop
|
325 |
+
if early_stop_flag:
|
326 |
+
break
|
327 |
+
return
|
328 |
+
|
329 |
+
|
330 |
+
if __name__ == "__main__":
|
331 |
+
main()
|
examples/vm_sound_classification/step_4_evaluation_model.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
from collections import defaultdict
|
5 |
+
import json
|
6 |
+
import logging
|
7 |
+
from logging.handlers import TimedRotatingFileHandler
|
8 |
+
import os
|
9 |
+
import platform
|
10 |
+
from pathlib import Path
|
11 |
+
import sys
|
12 |
+
import shutil
|
13 |
+
from typing import List
|
14 |
+
|
15 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
16 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
17 |
+
|
18 |
+
import pandas as pd
|
19 |
+
from scipy.io import wavfile
|
20 |
+
import torch
|
21 |
+
from tqdm import tqdm
|
22 |
+
|
23 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
24 |
+
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel
|
25 |
+
|
26 |
+
|
27 |
+
def get_args():
|
28 |
+
parser = argparse.ArgumentParser()
|
29 |
+
parser.add_argument("--dataset", default="dataset.xlsx", type=str)
|
30 |
+
|
31 |
+
parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
|
32 |
+
parser.add_argument("--model_dir", default="best", type=str)
|
33 |
+
|
34 |
+
parser.add_argument("--output_file", default="evaluation.xlsx", type=str)
|
35 |
+
|
36 |
+
args = parser.parse_args()
|
37 |
+
return args
|
38 |
+
|
39 |
+
|
40 |
+
def logging_config():
|
41 |
+
fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
|
42 |
+
|
43 |
+
logging.basicConfig(format=fmt,
|
44 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
45 |
+
level=logging.DEBUG)
|
46 |
+
stream_handler = logging.StreamHandler()
|
47 |
+
stream_handler.setLevel(logging.INFO)
|
48 |
+
stream_handler.setFormatter(logging.Formatter(fmt))
|
49 |
+
|
50 |
+
logger = logging.getLogger(__name__)
|
51 |
+
|
52 |
+
return logger
|
53 |
+
|
54 |
+
|
55 |
+
def main():
|
56 |
+
args = get_args()
|
57 |
+
|
58 |
+
logger = logging_config()
|
59 |
+
|
60 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
61 |
+
n_gpu = torch.cuda.device_count()
|
62 |
+
logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
|
63 |
+
|
64 |
+
logger.info("prepare vocabulary, model")
|
65 |
+
vocabulary = Vocabulary.from_files(args.vocabulary_dir)
|
66 |
+
|
67 |
+
model = WaveClassifierPretrainedModel.from_pretrained(
|
68 |
+
pretrained_model_name_or_path=args.model_dir,
|
69 |
+
)
|
70 |
+
model.to(device)
|
71 |
+
model.eval()
|
72 |
+
|
73 |
+
logger.info("read excel")
|
74 |
+
df = pd.read_excel(args.dataset)
|
75 |
+
result = list()
|
76 |
+
|
77 |
+
total_correct = 0
|
78 |
+
total_examples = 0
|
79 |
+
|
80 |
+
progress_bar = tqdm(total=len(df), desc="Evaluation")
|
81 |
+
for i, row in df.iterrows():
|
82 |
+
filename = row["filename"]
|
83 |
+
ground_true = row["labels"]
|
84 |
+
|
85 |
+
sample_rate, waveform = wavfile.read(filename)
|
86 |
+
waveform = waveform / (1 << 15)
|
87 |
+
waveform = torch.tensor(waveform, dtype=torch.float32)
|
88 |
+
waveform = torch.unsqueeze(waveform, dim=0)
|
89 |
+
waveform = waveform.to(device)
|
90 |
+
|
91 |
+
with torch.no_grad():
|
92 |
+
logits = model.forward(waveform)
|
93 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
94 |
+
label_idx = torch.argmax(probs, dim=-1)
|
95 |
+
|
96 |
+
label_idx = label_idx.cpu()
|
97 |
+
probs = probs.cpu()
|
98 |
+
|
99 |
+
label_idx = label_idx.numpy()[0]
|
100 |
+
label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
|
101 |
+
prob = probs[0][label_idx].numpy()
|
102 |
+
|
103 |
+
correct = 1 if label_str == ground_true else 0
|
104 |
+
row_ = dict(row)
|
105 |
+
row_["predict"] = label_str
|
106 |
+
row_["prob"] = prob
|
107 |
+
row_["correct"] = correct
|
108 |
+
result.append(row_)
|
109 |
+
|
110 |
+
total_examples += 1
|
111 |
+
total_correct += correct
|
112 |
+
accuracy = total_correct / total_examples
|
113 |
+
|
114 |
+
progress_bar.update(1)
|
115 |
+
progress_bar.set_postfix({
|
116 |
+
"accuracy": accuracy,
|
117 |
+
})
|
118 |
+
|
119 |
+
result = pd.DataFrame(result)
|
120 |
+
result.to_excel(
|
121 |
+
args.output_file,
|
122 |
+
index=False
|
123 |
+
)
|
124 |
+
return
|
125 |
+
|
126 |
+
|
127 |
+
if __name__ == '__main__':
|
128 |
+
main()
|
examples/vm_sound_classification/step_5_export_models.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
from collections import defaultdict
|
5 |
+
import json
|
6 |
+
import logging
|
7 |
+
from logging.handlers import TimedRotatingFileHandler
|
8 |
+
import os
|
9 |
+
import platform
|
10 |
+
from pathlib import Path
|
11 |
+
import sys
|
12 |
+
import shutil
|
13 |
+
from typing import List
|
14 |
+
|
15 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
16 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
|
21 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
22 |
+
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel
|
23 |
+
|
24 |
+
|
25 |
+
def get_args():
|
26 |
+
parser = argparse.ArgumentParser()
|
27 |
+
parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
|
28 |
+
parser.add_argument("--model_dir", default="best", type=str)
|
29 |
+
|
30 |
+
parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
|
31 |
+
|
32 |
+
args = parser.parse_args()
|
33 |
+
return args
|
34 |
+
|
35 |
+
|
36 |
+
def logging_config():
|
37 |
+
fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
|
38 |
+
|
39 |
+
logging.basicConfig(format=fmt,
|
40 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
41 |
+
level=logging.DEBUG)
|
42 |
+
stream_handler = logging.StreamHandler()
|
43 |
+
stream_handler.setLevel(logging.INFO)
|
44 |
+
stream_handler.setFormatter(logging.Formatter(fmt))
|
45 |
+
|
46 |
+
logger = logging.getLogger(__name__)
|
47 |
+
|
48 |
+
return logger
|
49 |
+
|
50 |
+
|
51 |
+
def main():
|
52 |
+
args = get_args()
|
53 |
+
|
54 |
+
serialization_dir = Path(args.serialization_dir)
|
55 |
+
|
56 |
+
logger = logging_config()
|
57 |
+
|
58 |
+
logger.info("export models on CPU")
|
59 |
+
device = torch.device("cpu")
|
60 |
+
|
61 |
+
logger.info("prepare vocabulary, model")
|
62 |
+
vocabulary = Vocabulary.from_files(args.vocabulary_dir)
|
63 |
+
|
64 |
+
model = WaveClassifierPretrainedModel.from_pretrained(
|
65 |
+
pretrained_model_name_or_path=args.model_dir,
|
66 |
+
num_labels=vocabulary.get_vocab_size(namespace="labels")
|
67 |
+
)
|
68 |
+
model.to(device)
|
69 |
+
model.eval()
|
70 |
+
|
71 |
+
waveform = 0 + 25 * np.random.randn(16000,)
|
72 |
+
waveform = np.array(waveform, dtype=np.int16)
|
73 |
+
waveform = waveform / (1 << 15)
|
74 |
+
waveform = torch.tensor(waveform, dtype=torch.float32)
|
75 |
+
waveform = torch.unsqueeze(waveform, dim=0)
|
76 |
+
waveform = waveform.to(device)
|
77 |
+
|
78 |
+
logger.info("export jit models")
|
79 |
+
example_inputs = (waveform,)
|
80 |
+
|
81 |
+
# trace model
|
82 |
+
trace_model = torch.jit.trace(func=model, example_inputs=example_inputs, strict=False)
|
83 |
+
trace_model.save(serialization_dir / "trace_model.zip")
|
84 |
+
|
85 |
+
# quantization trace model (not work on GPU)
|
86 |
+
quantized_model = torch.quantization.quantize_dynamic(
|
87 |
+
model, {torch.nn.Linear}, dtype=torch.qint8
|
88 |
+
)
|
89 |
+
trace_quant_model = torch.jit.trace(func=quantized_model, example_inputs=example_inputs, strict=False)
|
90 |
+
trace_quant_model.save(serialization_dir / "trace_quant_model.zip")
|
91 |
+
|
92 |
+
# script model
|
93 |
+
script_model = torch.jit.script(obj=model)
|
94 |
+
script_model.save(serialization_dir / "script_model.zip")
|
95 |
+
|
96 |
+
# quantization script model (not work on GPU)
|
97 |
+
quantized_model = torch.quantization.quantize_dynamic(
|
98 |
+
model, {torch.nn.Linear}, dtype=torch.qint8
|
99 |
+
)
|
100 |
+
script_quant_model = torch.jit.script(quantized_model)
|
101 |
+
script_quant_model.save(serialization_dir / "script_quant_model.zip")
|
102 |
+
return
|
103 |
+
|
104 |
+
|
105 |
+
if __name__ == '__main__':
|
106 |
+
main()
|
examples/vm_sound_classification/step_6_infer.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
import shutil
|
7 |
+
import sys
|
8 |
+
import tempfile
|
9 |
+
import zipfile
|
10 |
+
|
11 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
12 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
13 |
+
|
14 |
+
from scipy.io import wavfile
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from project_settings import project_path
|
18 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
19 |
+
|
20 |
+
|
21 |
+
def get_args():
|
22 |
+
parser = argparse.ArgumentParser()
|
23 |
+
parser.add_argument(
|
24 |
+
"--model_file",
|
25 |
+
default=(project_path / "trained_models/vm_sound_classification3.zip").as_posix(),
|
26 |
+
type=str
|
27 |
+
)
|
28 |
+
parser.add_argument(
|
29 |
+
"--wav_file",
|
30 |
+
default=r"C:\Users\tianx\Desktop\a073d03d-d280-46df-9b2d-d904965f4500_zh-CN_h3f25ivhb0c0_1719478037746.wav",
|
31 |
+
type=str
|
32 |
+
)
|
33 |
+
|
34 |
+
parser.add_argument("--device", default="cpu", type=str)
|
35 |
+
|
36 |
+
args = parser.parse_args()
|
37 |
+
return args
|
38 |
+
|
39 |
+
|
40 |
+
def main():
|
41 |
+
args = get_args()
|
42 |
+
|
43 |
+
model_file = Path(args.model_file)
|
44 |
+
|
45 |
+
device = torch.device(args.device)
|
46 |
+
|
47 |
+
with zipfile.ZipFile(model_file, "r") as f_zip:
|
48 |
+
out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
|
49 |
+
print(out_root.as_posix())
|
50 |
+
if out_root.exists():
|
51 |
+
shutil.rmtree(out_root.as_posix())
|
52 |
+
out_root.mkdir(parents=True, exist_ok=True)
|
53 |
+
f_zip.extractall(path=out_root)
|
54 |
+
|
55 |
+
tgt_path = out_root / model_file.stem
|
56 |
+
jit_model_file = tgt_path / "trace_model.zip"
|
57 |
+
vocab_path = tgt_path / "vocabulary"
|
58 |
+
|
59 |
+
with open(jit_model_file.as_posix(), "rb") as f:
|
60 |
+
model = torch.jit.load(f)
|
61 |
+
model.to(device)
|
62 |
+
model.eval()
|
63 |
+
vocabulary = Vocabulary.from_files(vocab_path.as_posix())
|
64 |
+
|
65 |
+
# infer
|
66 |
+
sample_rate, waveform = wavfile.read(args.wav_file)
|
67 |
+
waveform = waveform[:16000]
|
68 |
+
waveform = waveform / (1 << 15)
|
69 |
+
waveform = torch.tensor(waveform, dtype=torch.float32)
|
70 |
+
waveform = torch.unsqueeze(waveform, dim=0)
|
71 |
+
waveform = waveform.to(device)
|
72 |
+
|
73 |
+
with torch.no_grad():
|
74 |
+
logits = model.forward(waveform)
|
75 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
76 |
+
label_idx = torch.argmax(probs, dim=-1)
|
77 |
+
|
78 |
+
label_idx = label_idx.cpu()
|
79 |
+
probs = probs.cpu()
|
80 |
+
|
81 |
+
label_idx = label_idx.numpy()[0]
|
82 |
+
prob = probs.numpy()[0][label_idx]
|
83 |
+
|
84 |
+
label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
|
85 |
+
print(label_str)
|
86 |
+
print(prob)
|
87 |
+
return
|
88 |
+
|
89 |
+
|
90 |
+
if __name__ == '__main__':
|
91 |
+
main()
|
examples/vm_sound_classification/step_7_test_model.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
import shutil
|
7 |
+
import sys
|
8 |
+
import tempfile
|
9 |
+
import zipfile
|
10 |
+
|
11 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
12 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
13 |
+
|
14 |
+
from scipy.io import wavfile
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from project_settings import project_path
|
18 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
19 |
+
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel
|
20 |
+
|
21 |
+
|
22 |
+
def get_args():
|
23 |
+
parser = argparse.ArgumentParser()
|
24 |
+
parser.add_argument(
|
25 |
+
"--model_file",
|
26 |
+
default=(project_path / "trained_models/vm_sound_classification3.zip").as_posix(),
|
27 |
+
type=str
|
28 |
+
)
|
29 |
+
parser.add_argument(
|
30 |
+
"--wav_file",
|
31 |
+
default=r"C:\Users\tianx\Desktop\4b284733-0be3-4a48-abbb-615b32ac44b7_6ndddc2szlh0.wav",
|
32 |
+
type=str
|
33 |
+
)
|
34 |
+
|
35 |
+
parser.add_argument("--device", default="cpu", type=str)
|
36 |
+
|
37 |
+
args = parser.parse_args()
|
38 |
+
return args
|
39 |
+
|
40 |
+
|
41 |
+
def main():
|
42 |
+
args = get_args()
|
43 |
+
|
44 |
+
model_file = Path(args.model_file)
|
45 |
+
|
46 |
+
device = torch.device(args.device)
|
47 |
+
|
48 |
+
with zipfile.ZipFile(model_file, "r") as f_zip:
|
49 |
+
out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
|
50 |
+
print(out_root)
|
51 |
+
if out_root.exists():
|
52 |
+
shutil.rmtree(out_root.as_posix())
|
53 |
+
out_root.mkdir(parents=True, exist_ok=True)
|
54 |
+
f_zip.extractall(path=out_root)
|
55 |
+
|
56 |
+
tgt_path = out_root / model_file.stem
|
57 |
+
vocab_path = tgt_path / "vocabulary"
|
58 |
+
|
59 |
+
vocabulary = Vocabulary.from_files(vocab_path.as_posix())
|
60 |
+
|
61 |
+
model = WaveClassifierPretrainedModel.from_pretrained(
|
62 |
+
pretrained_model_name_or_path=tgt_path.as_posix(),
|
63 |
+
)
|
64 |
+
model.to(device)
|
65 |
+
model.eval()
|
66 |
+
|
67 |
+
# infer
|
68 |
+
sample_rate, waveform = wavfile.read(args.wav_file)
|
69 |
+
waveform = waveform[:16000]
|
70 |
+
waveform = waveform / (1 << 15)
|
71 |
+
waveform = torch.tensor(waveform, dtype=torch.float32)
|
72 |
+
waveform = torch.unsqueeze(waveform, dim=0)
|
73 |
+
waveform = waveform.to(device)
|
74 |
+
print(waveform.shape)
|
75 |
+
with torch.no_grad():
|
76 |
+
logits = model.forward(waveform)
|
77 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
78 |
+
label_idx = torch.argmax(probs, dim=-1)
|
79 |
+
|
80 |
+
label_idx = label_idx.cpu()
|
81 |
+
probs = probs.cpu()
|
82 |
+
|
83 |
+
label_idx = label_idx.numpy()[0]
|
84 |
+
prob = probs.numpy()[0][label_idx]
|
85 |
+
|
86 |
+
label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
|
87 |
+
print(label_str)
|
88 |
+
print(prob)
|
89 |
+
return
|
90 |
+
|
91 |
+
|
92 |
+
if __name__ == '__main__':
|
93 |
+
main()
|
examples/vm_sound_classification/stop.sh
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
kill -9 `ps -aef | grep 'vm_sound_classification/bin/python3' | grep -v grep | awk '{print $2}' | sed 's/\n/ /'`
|
examples/vm_sound_classification8/requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==1.10.1
|
2 |
+
torchaudio==0.10.1
|
3 |
+
fsspec==2022.1.0
|
4 |
+
librosa==0.9.2
|
5 |
+
pandas==1.1.5
|
6 |
+
openpyxl==3.0.9
|
7 |
+
xlrd==1.2.0
|
8 |
+
tqdm==4.64.1
|
9 |
+
overrides==1.9.0
|
examples/vm_sound_classification8/run.sh
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
: <<'END'
|
4 |
+
|
5 |
+
sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification8 \
|
6 |
+
--filename_patterns "E:/programmer/asr_datasets/voicemail/wav_finished/en-US/wav_finished/*/*.wav \
|
7 |
+
E:/programmer/asr_datasets/voicemail/wav_finished/id-ID/wav_finished/*/*.wav" \
|
8 |
+
|
9 |
+
|
10 |
+
sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification8 \
|
11 |
+
--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
|
12 |
+
|
13 |
+
sh run.sh --stage 4 --stop_stage 4 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification8 \
|
14 |
+
--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
|
15 |
+
|
16 |
+
sh run.sh --stage 4 --stop_stage 4 --system_version centos --file_folder_name file_dir --final_model_name vm_sound_classification8 \
|
17 |
+
--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
|
18 |
+
|
19 |
+
|
20 |
+
"
|
21 |
+
|
22 |
+
END
|
23 |
+
|
24 |
+
|
25 |
+
# sh run.sh --stage -1 --stop_stage 9
|
26 |
+
# sh run.sh --stage -1 --stop_stage 5 --system_version centos --file_folder_name task_cnn_voicemail_id_id --final_model_name cnn_voicemail_id_id
|
27 |
+
# sh run.sh --stage 3 --stop_stage 4
|
28 |
+
# sh run.sh --stage 4 --stop_stage 4
|
29 |
+
# sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name task_cnn_voicemail_id_id
|
30 |
+
|
31 |
+
# params
|
32 |
+
system_version="windows";
|
33 |
+
verbose=true;
|
34 |
+
stage=0 # start from 0 if you need to start from data preparation
|
35 |
+
stop_stage=9
|
36 |
+
|
37 |
+
work_dir="$(pwd)"
|
38 |
+
file_folder_name=file_folder_name
|
39 |
+
final_model_name=final_model_name
|
40 |
+
filename_patterns="/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
|
41 |
+
nohup_name=nohup.out
|
42 |
+
|
43 |
+
country=en-US
|
44 |
+
|
45 |
+
# model params
|
46 |
+
batch_size=64
|
47 |
+
max_epochs=200
|
48 |
+
save_top_k=10
|
49 |
+
patience=5
|
50 |
+
|
51 |
+
|
52 |
+
# parse options
|
53 |
+
while true; do
|
54 |
+
[ -z "${1:-}" ] && break; # break if there are no arguments
|
55 |
+
case "$1" in
|
56 |
+
--*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
|
57 |
+
eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
|
58 |
+
old_value="(eval echo \\$$name)";
|
59 |
+
if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
|
60 |
+
was_bool=true;
|
61 |
+
else
|
62 |
+
was_bool=false;
|
63 |
+
fi
|
64 |
+
|
65 |
+
# Set the variable to the right value-- the escaped quotes make it work if
|
66 |
+
# the option had spaces, like --cmd "queue.pl -sync y"
|
67 |
+
eval "${name}=\"$2\"";
|
68 |
+
|
69 |
+
# Check that Boolean-valued arguments are really Boolean.
|
70 |
+
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
|
71 |
+
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
|
72 |
+
exit 1;
|
73 |
+
fi
|
74 |
+
shift 2;
|
75 |
+
;;
|
76 |
+
|
77 |
+
*) break;
|
78 |
+
esac
|
79 |
+
done
|
80 |
+
|
81 |
+
file_dir="${work_dir}/${file_folder_name}"
|
82 |
+
final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
|
83 |
+
|
84 |
+
train_dataset="${file_dir}/train.xlsx"
|
85 |
+
valid_dataset="${file_dir}/valid.xlsx"
|
86 |
+
vocabulary_dir="${file_dir}/vocabulary"
|
87 |
+
|
88 |
+
|
89 |
+
$verbose && echo "system_version: ${system_version}"
|
90 |
+
$verbose && echo "file_folder_name: ${file_folder_name}"
|
91 |
+
|
92 |
+
if [ $system_version == "windows" ]; then
|
93 |
+
alias python3='D:/Users/tianx/PycharmProjects/virtualenv/vm_sound_classification/Scripts/python.exe'
|
94 |
+
elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
|
95 |
+
#source /data/local/bin/vm_sound_classification/bin/activate
|
96 |
+
alias python3='/data/local/bin/vm_sound_classification/bin/python3'
|
97 |
+
fi
|
98 |
+
|
99 |
+
|
100 |
+
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
101 |
+
$verbose && echo "stage 0: prepare data"
|
102 |
+
cd "${work_dir}" || exit 1
|
103 |
+
python3 step_1_prepare_data.py \
|
104 |
+
--file_dir "${file_dir}" \
|
105 |
+
--filename_patterns "${filename_patterns}" \
|
106 |
+
--train_dataset "${train_dataset}" \
|
107 |
+
--valid_dataset "${valid_dataset}" \
|
108 |
+
|
109 |
+
fi
|
110 |
+
|
111 |
+
|
112 |
+
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
113 |
+
$verbose && echo "stage 1: make vocabulary"
|
114 |
+
cd "${work_dir}" || exit 1
|
115 |
+
python3 step_2_make_vocabulary.py \
|
116 |
+
--vocabulary_dir "${vocabulary_dir}" \
|
117 |
+
--train_dataset "${train_dataset}" \
|
118 |
+
--valid_dataset "${valid_dataset}" \
|
119 |
+
|
120 |
+
fi
|
121 |
+
|
122 |
+
|
123 |
+
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
124 |
+
$verbose && echo "stage 2: train global model"
|
125 |
+
cd "${work_dir}" || exit 1
|
126 |
+
python3 step_3_train_global_model.py \
|
127 |
+
--vocabulary_dir "${vocabulary_dir}" \
|
128 |
+
--train_dataset "${train_dataset}" \
|
129 |
+
--valid_dataset "${valid_dataset}" \
|
130 |
+
--serialization_dir "${file_dir}/global_model" \
|
131 |
+
|
132 |
+
fi
|
133 |
+
|
134 |
+
|
135 |
+
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
136 |
+
$verbose && echo "stage 3: train country model"
|
137 |
+
cd "${work_dir}" || exit 1
|
138 |
+
python3 step_4_train_country_model.py \
|
139 |
+
--vocabulary_dir "${vocabulary_dir}" \
|
140 |
+
--train_dataset "${train_dataset}" \
|
141 |
+
--valid_dataset "${valid_dataset}" \
|
142 |
+
--country "${country}" \
|
143 |
+
--serialization_dir "${file_dir}/country_model" \
|
144 |
+
|
145 |
+
fi
|
146 |
+
|
147 |
+
|
148 |
+
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
149 |
+
$verbose && echo "stage 4: train union model"
|
150 |
+
cd "${work_dir}" || exit 1
|
151 |
+
python3 step_5_train_union.py \
|
152 |
+
--vocabulary_dir "${vocabulary_dir}" \
|
153 |
+
--train_dataset "${train_dataset}" \
|
154 |
+
--valid_dataset "${valid_dataset}" \
|
155 |
+
--serialization_dir "${file_dir}/union" \
|
156 |
+
|
157 |
+
fi
|
examples/vm_sound_classification8/step_1_prepare_data.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
from glob import glob
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
from pathlib import Path
|
8 |
+
import random
|
9 |
+
import sys
|
10 |
+
|
11 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
12 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
13 |
+
|
14 |
+
import pandas as pd
|
15 |
+
from scipy.io import wavfile
|
16 |
+
from tqdm import tqdm
|
17 |
+
|
18 |
+
|
19 |
+
def get_args():
|
20 |
+
parser = argparse.ArgumentParser()
|
21 |
+
parser.add_argument("--file_dir", default="./", type=str)
|
22 |
+
parser.add_argument("--task", default="default", type=str)
|
23 |
+
parser.add_argument("--filename_patterns", type=str)
|
24 |
+
|
25 |
+
parser.add_argument("--train_dataset", default="train.xlsx", type=str)
|
26 |
+
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
|
27 |
+
|
28 |
+
args = parser.parse_args()
|
29 |
+
return args
|
30 |
+
|
31 |
+
|
32 |
+
def get_dataset(args):
|
33 |
+
filename_patterns = args.filename_patterns
|
34 |
+
filename_patterns = filename_patterns.split(" ")
|
35 |
+
print(filename_patterns)
|
36 |
+
|
37 |
+
file_dir = Path(args.file_dir)
|
38 |
+
file_dir.mkdir(exist_ok=True)
|
39 |
+
|
40 |
+
global_label_map = {
|
41 |
+
"bell": "bell",
|
42 |
+
"white_noise": "white_noise",
|
43 |
+
"low_white_noise": "white_noise",
|
44 |
+
"high_white_noise": "noise",
|
45 |
+
"music": "music",
|
46 |
+
"mute": "mute",
|
47 |
+
"noise": "noise",
|
48 |
+
"noise_mute": "noise_mute",
|
49 |
+
"voice": "voice",
|
50 |
+
"voicemail": "voicemail",
|
51 |
+
}
|
52 |
+
|
53 |
+
country_label_map = {
|
54 |
+
"bell": "voicemail",
|
55 |
+
"white_noise": "non_voicemail",
|
56 |
+
"low_white_noise": "non_voicemail",
|
57 |
+
"hight_white_noise": "non_voicemail",
|
58 |
+
"music": "non_voicemail",
|
59 |
+
"mute": "non_voicemail",
|
60 |
+
"noise": "non_voicemail",
|
61 |
+
"noise_mute": "non_voicemail",
|
62 |
+
"voice": "non_voicemail",
|
63 |
+
"voicemail": "voicemail",
|
64 |
+
"non_voicemail": "non_voicemail",
|
65 |
+
}
|
66 |
+
|
67 |
+
result = list()
|
68 |
+
for filename_pattern in filename_patterns:
|
69 |
+
filename_list = glob(filename_pattern)
|
70 |
+
for filename in tqdm(filename_list):
|
71 |
+
filename = Path(filename)
|
72 |
+
sample_rate, signal = wavfile.read(filename.as_posix())
|
73 |
+
if len(signal) < sample_rate * 2:
|
74 |
+
continue
|
75 |
+
|
76 |
+
folder = filename.parts[-2]
|
77 |
+
country = filename.parts[-4]
|
78 |
+
|
79 |
+
if folder not in global_label_map.keys():
|
80 |
+
continue
|
81 |
+
if folder not in country_label_map.keys():
|
82 |
+
continue
|
83 |
+
|
84 |
+
global_label = global_label_map[folder]
|
85 |
+
country_label = country_label_map[folder]
|
86 |
+
|
87 |
+
random1 = random.random()
|
88 |
+
random2 = random.random()
|
89 |
+
|
90 |
+
result.append({
|
91 |
+
"filename": filename,
|
92 |
+
"folder": folder,
|
93 |
+
"category": country,
|
94 |
+
"global_labels": global_label,
|
95 |
+
"country_labels": country_label,
|
96 |
+
"random1": random1,
|
97 |
+
"random2": random2,
|
98 |
+
"flag": "TRAIN" if random2 < 0.8 else "TEST",
|
99 |
+
})
|
100 |
+
|
101 |
+
df = pd.DataFrame(result)
|
102 |
+
pivot_table = pd.pivot_table(df, index=["global_labels"], values=["filename"], aggfunc="count")
|
103 |
+
print(pivot_table)
|
104 |
+
|
105 |
+
df = df.sort_values(by=["random1"], ascending=False)
|
106 |
+
df.to_excel(
|
107 |
+
file_dir / "dataset.xlsx",
|
108 |
+
index=False,
|
109 |
+
# encoding="utf_8_sig"
|
110 |
+
)
|
111 |
+
|
112 |
+
return
|
113 |
+
|
114 |
+
|
115 |
+
def split_dataset(args):
|
116 |
+
"""分割训练集, 测试集"""
|
117 |
+
file_dir = Path(args.file_dir)
|
118 |
+
file_dir.mkdir(exist_ok=True)
|
119 |
+
|
120 |
+
df = pd.read_excel(file_dir / "dataset.xlsx")
|
121 |
+
|
122 |
+
train = list()
|
123 |
+
test = list()
|
124 |
+
|
125 |
+
for i, row in df.iterrows():
|
126 |
+
flag = row["flag"]
|
127 |
+
if flag == "TRAIN":
|
128 |
+
train.append(row)
|
129 |
+
else:
|
130 |
+
test.append(row)
|
131 |
+
|
132 |
+
train = pd.DataFrame(train)
|
133 |
+
train.to_excel(
|
134 |
+
args.train_dataset,
|
135 |
+
index=False,
|
136 |
+
# encoding="utf_8_sig"
|
137 |
+
)
|
138 |
+
test = pd.DataFrame(test)
|
139 |
+
test.to_excel(
|
140 |
+
args.valid_dataset,
|
141 |
+
index=False,
|
142 |
+
# encoding="utf_8_sig"
|
143 |
+
)
|
144 |
+
|
145 |
+
return
|
146 |
+
|
147 |
+
|
148 |
+
def main():
|
149 |
+
args = get_args()
|
150 |
+
get_dataset(args)
|
151 |
+
split_dataset(args)
|
152 |
+
return
|
153 |
+
|
154 |
+
|
155 |
+
if __name__ == "__main__":
|
156 |
+
main()
|
examples/vm_sound_classification8/step_2_make_vocabulary.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
import sys
|
7 |
+
|
8 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
9 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
10 |
+
|
11 |
+
import pandas as pd
|
12 |
+
|
13 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
14 |
+
|
15 |
+
|
16 |
+
def get_args():
|
17 |
+
parser = argparse.ArgumentParser()
|
18 |
+
parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
|
19 |
+
|
20 |
+
parser.add_argument("--train_dataset", default="train.xlsx", type=str)
|
21 |
+
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
|
22 |
+
|
23 |
+
args = parser.parse_args()
|
24 |
+
return args
|
25 |
+
|
26 |
+
|
27 |
+
def main():
|
28 |
+
args = get_args()
|
29 |
+
|
30 |
+
train_dataset = pd.read_excel(args.train_dataset)
|
31 |
+
valid_dataset = pd.read_excel(args.valid_dataset)
|
32 |
+
|
33 |
+
# non_padded_namespaces
|
34 |
+
category_set = set()
|
35 |
+
for i, row in train_dataset.iterrows():
|
36 |
+
category = row["category"]
|
37 |
+
category_set.add(category)
|
38 |
+
|
39 |
+
for i, row in valid_dataset.iterrows():
|
40 |
+
category = row["category"]
|
41 |
+
category_set.add(category)
|
42 |
+
|
43 |
+
vocabulary = Vocabulary(non_padded_namespaces=["global_labels", *list(category_set)])
|
44 |
+
|
45 |
+
# train
|
46 |
+
for i, row in train_dataset.iterrows():
|
47 |
+
global_labels = row["global_labels"]
|
48 |
+
country_labels = row["country_labels"]
|
49 |
+
category = row["category"]
|
50 |
+
|
51 |
+
vocabulary.add_token_to_namespace(global_labels, "global_labels")
|
52 |
+
vocabulary.add_token_to_namespace(country_labels, category)
|
53 |
+
|
54 |
+
# valid
|
55 |
+
for i, row in valid_dataset.iterrows():
|
56 |
+
global_labels = row["global_labels"]
|
57 |
+
country_labels = row["country_labels"]
|
58 |
+
category = row["category"]
|
59 |
+
|
60 |
+
vocabulary.add_token_to_namespace(global_labels, "global_labels")
|
61 |
+
vocabulary.add_token_to_namespace(country_labels, category)
|
62 |
+
|
63 |
+
vocabulary.save_to_files(args.vocabulary_dir)
|
64 |
+
|
65 |
+
return
|
66 |
+
|
67 |
+
|
68 |
+
if __name__ == "__main__":
|
69 |
+
main()
|
examples/vm_sound_classification8/step_3_train_global_model.py
ADDED
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
之前的代码达到准确率0.8423
|
5 |
+
此代码达到准确率0.8379
|
6 |
+
此代码可行.
|
7 |
+
"""
|
8 |
+
import argparse
|
9 |
+
import copy
|
10 |
+
import json
|
11 |
+
import logging
|
12 |
+
from logging.handlers import TimedRotatingFileHandler
|
13 |
+
import os
|
14 |
+
from pathlib import Path
|
15 |
+
import platform
|
16 |
+
import sys
|
17 |
+
from typing import List
|
18 |
+
|
19 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
20 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
21 |
+
|
22 |
+
import torch
|
23 |
+
from torch.utils.data.dataloader import DataLoader
|
24 |
+
from tqdm import tqdm
|
25 |
+
|
26 |
+
from toolbox.torch.modules.loss import FocalLoss, HingeLoss, HingeLinear
|
27 |
+
from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
|
28 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
29 |
+
from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
|
30 |
+
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveEncoder, ClsHead, WaveClassifier
|
31 |
+
|
32 |
+
|
33 |
+
def get_args():
|
34 |
+
parser = argparse.ArgumentParser()
|
35 |
+
parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
|
36 |
+
|
37 |
+
parser.add_argument("--train_dataset", default="train.xlsx", type=str)
|
38 |
+
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
|
39 |
+
|
40 |
+
parser.add_argument("--max_epochs", default=100, type=int)
|
41 |
+
parser.add_argument("--batch_size", default=64, type=int)
|
42 |
+
parser.add_argument("--learning_rate", default=1e-3, type=float)
|
43 |
+
parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
|
44 |
+
parser.add_argument("--patience", default=5, type=int)
|
45 |
+
parser.add_argument("--serialization_dir", default="global_classifier", type=str)
|
46 |
+
parser.add_argument("--seed", default=0, type=int)
|
47 |
+
|
48 |
+
args = parser.parse_args()
|
49 |
+
return args
|
50 |
+
|
51 |
+
|
52 |
+
def logging_config(file_dir: str):
|
53 |
+
fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
|
54 |
+
|
55 |
+
logging.basicConfig(format=fmt,
|
56 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
57 |
+
level=logging.DEBUG)
|
58 |
+
file_handler = TimedRotatingFileHandler(
|
59 |
+
filename=os.path.join(file_dir, "main.log"),
|
60 |
+
encoding="utf-8",
|
61 |
+
when="D",
|
62 |
+
interval=1,
|
63 |
+
backupCount=7
|
64 |
+
)
|
65 |
+
file_handler.setLevel(logging.INFO)
|
66 |
+
file_handler.setFormatter(logging.Formatter(fmt))
|
67 |
+
logger = logging.getLogger(__name__)
|
68 |
+
logger.addHandler(file_handler)
|
69 |
+
|
70 |
+
return logger
|
71 |
+
|
72 |
+
|
73 |
+
class CollateFunction(object):
|
74 |
+
def __init__(self):
|
75 |
+
pass
|
76 |
+
|
77 |
+
def __call__(self, batch: List[dict]):
|
78 |
+
array_list = list()
|
79 |
+
label_list = list()
|
80 |
+
for sample in batch:
|
81 |
+
array = sample["waveform"]
|
82 |
+
label = sample["label"]
|
83 |
+
|
84 |
+
array_list.append(array)
|
85 |
+
label_list.append(label)
|
86 |
+
|
87 |
+
array_list = torch.stack(array_list)
|
88 |
+
label_list = torch.stack(label_list)
|
89 |
+
return array_list, label_list
|
90 |
+
|
91 |
+
|
92 |
+
collate_fn = CollateFunction()
|
93 |
+
|
94 |
+
|
95 |
+
def main():
|
96 |
+
args = get_args()
|
97 |
+
|
98 |
+
serialization_dir = Path(args.serialization_dir)
|
99 |
+
serialization_dir.mkdir(parents=True, exist_ok=True)
|
100 |
+
|
101 |
+
logger = logging_config(args.serialization_dir)
|
102 |
+
|
103 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
104 |
+
n_gpu = torch.cuda.device_count()
|
105 |
+
logger.info("GPU available: {}; device: {}".format(n_gpu, device))
|
106 |
+
|
107 |
+
vocabulary = Vocabulary.from_files(args.vocabulary_dir)
|
108 |
+
|
109 |
+
# datasets
|
110 |
+
train_dataset = WaveClassifierExcelDataset(
|
111 |
+
vocab=vocabulary,
|
112 |
+
excel_file=args.train_dataset,
|
113 |
+
category=None,
|
114 |
+
category_field="category",
|
115 |
+
label_field="global_labels",
|
116 |
+
expected_sample_rate=8000,
|
117 |
+
max_wave_value=32768.0,
|
118 |
+
)
|
119 |
+
valid_dataset = WaveClassifierExcelDataset(
|
120 |
+
vocab=vocabulary,
|
121 |
+
excel_file=args.valid_dataset,
|
122 |
+
category=None,
|
123 |
+
category_field="category",
|
124 |
+
label_field="global_labels",
|
125 |
+
expected_sample_rate=8000,
|
126 |
+
max_wave_value=32768.0,
|
127 |
+
)
|
128 |
+
|
129 |
+
train_data_loader = DataLoader(
|
130 |
+
dataset=train_dataset,
|
131 |
+
batch_size=args.batch_size,
|
132 |
+
shuffle=True,
|
133 |
+
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
134 |
+
num_workers=0 if platform.system() == "Windows" else os.cpu_count(),
|
135 |
+
collate_fn=collate_fn,
|
136 |
+
pin_memory=False,
|
137 |
+
# prefetch_factor=64,
|
138 |
+
)
|
139 |
+
valid_data_loader = DataLoader(
|
140 |
+
dataset=valid_dataset,
|
141 |
+
batch_size=args.batch_size,
|
142 |
+
shuffle=True,
|
143 |
+
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
144 |
+
num_workers=0 if platform.system() == "Windows" else os.cpu_count(),
|
145 |
+
collate_fn=collate_fn,
|
146 |
+
pin_memory=False,
|
147 |
+
# prefetch_factor=64,
|
148 |
+
)
|
149 |
+
|
150 |
+
# models - classifier
|
151 |
+
wave_encoder = WaveEncoder(
|
152 |
+
conv1d_block_param_list=[
|
153 |
+
{
|
154 |
+
'batch_norm': True,
|
155 |
+
'in_channels': 80,
|
156 |
+
'out_channels': 16,
|
157 |
+
'kernel_size': 3,
|
158 |
+
'stride': 3,
|
159 |
+
# 'padding': 'same',
|
160 |
+
'activation': 'relu',
|
161 |
+
'dropout': 0.1,
|
162 |
+
},
|
163 |
+
{
|
164 |
+
# 'batch_norm': True,
|
165 |
+
'in_channels': 16,
|
166 |
+
'out_channels': 16,
|
167 |
+
'kernel_size': 3,
|
168 |
+
'stride': 3,
|
169 |
+
# 'padding': 'same',
|
170 |
+
'activation': 'relu',
|
171 |
+
'dropout': 0.1,
|
172 |
+
},
|
173 |
+
{
|
174 |
+
# 'batch_norm': True,
|
175 |
+
'in_channels': 16,
|
176 |
+
'out_channels': 16,
|
177 |
+
'kernel_size': 3,
|
178 |
+
'stride': 3,
|
179 |
+
# 'padding': 'same',
|
180 |
+
'activation': 'relu',
|
181 |
+
'dropout': 0.1,
|
182 |
+
},
|
183 |
+
],
|
184 |
+
mel_spectrogram_param={
|
185 |
+
"sample_rate": 8000,
|
186 |
+
"n_fft": 512,
|
187 |
+
"win_length": 200,
|
188 |
+
"hop_length": 80,
|
189 |
+
"f_min": 10,
|
190 |
+
"f_max": 3800,
|
191 |
+
"window_fn": "hamming",
|
192 |
+
"n_mels": 80,
|
193 |
+
}
|
194 |
+
)
|
195 |
+
cls_head = ClsHead(
|
196 |
+
input_dim=16,
|
197 |
+
num_layers=2,
|
198 |
+
hidden_dims=[32, 16],
|
199 |
+
activations="relu",
|
200 |
+
dropout=0.1,
|
201 |
+
num_labels=vocabulary.get_vocab_size(namespace="global_labels")
|
202 |
+
)
|
203 |
+
model = WaveClassifier(
|
204 |
+
wave_encoder=wave_encoder,
|
205 |
+
cls_head=cls_head,
|
206 |
+
)
|
207 |
+
model.to(device)
|
208 |
+
|
209 |
+
# optimizer
|
210 |
+
optimizer = torch.optim.Adam(
|
211 |
+
model.parameters(),
|
212 |
+
lr=args.learning_rate
|
213 |
+
)
|
214 |
+
lr_scheduler = torch.optim.lr_scheduler.StepLR(
|
215 |
+
optimizer,
|
216 |
+
step_size=30000
|
217 |
+
)
|
218 |
+
focal_loss = FocalLoss(
|
219 |
+
num_classes=vocabulary.get_vocab_size(namespace="global_labels"),
|
220 |
+
reduction="mean",
|
221 |
+
)
|
222 |
+
categorical_accuracy = CategoricalAccuracy()
|
223 |
+
|
224 |
+
# training
|
225 |
+
best_idx_epoch: int = None
|
226 |
+
best_accuracy: float = None
|
227 |
+
patience_count = 0
|
228 |
+
global_step = 0
|
229 |
+
model_filename_list = list()
|
230 |
+
for idx_epoch in range(args.max_epochs):
|
231 |
+
|
232 |
+
# training
|
233 |
+
model.train()
|
234 |
+
total_loss = 0
|
235 |
+
total_examples = 0
|
236 |
+
for step, batch in enumerate(tqdm(train_data_loader, desc="Epoch={} (training)".format(idx_epoch))):
|
237 |
+
input_ids, label_ids = batch
|
238 |
+
input_ids = input_ids.to(device)
|
239 |
+
label_ids: torch.LongTensor = label_ids.to(device).long()
|
240 |
+
|
241 |
+
logits = model.forward(input_ids)
|
242 |
+
loss = focal_loss.forward(logits, label_ids.view(-1))
|
243 |
+
categorical_accuracy(logits, label_ids)
|
244 |
+
|
245 |
+
total_loss += loss.item()
|
246 |
+
total_examples += input_ids.size(0)
|
247 |
+
|
248 |
+
optimizer.zero_grad()
|
249 |
+
loss.backward()
|
250 |
+
optimizer.step()
|
251 |
+
lr_scheduler.step()
|
252 |
+
|
253 |
+
global_step += 1
|
254 |
+
training_loss = total_loss / total_examples
|
255 |
+
training_loss = round(training_loss, 4)
|
256 |
+
training_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
|
257 |
+
training_accuracy = round(training_accuracy, 4)
|
258 |
+
logger.info("Epoch: {}; training_loss: {}; training_accuracy: {}".format(
|
259 |
+
idx_epoch, training_loss, training_accuracy
|
260 |
+
))
|
261 |
+
|
262 |
+
# evaluation
|
263 |
+
model.eval()
|
264 |
+
total_loss = 0
|
265 |
+
total_examples = 0
|
266 |
+
for step, batch in enumerate(tqdm(valid_data_loader, desc="Epoch={} (evaluation)".format(idx_epoch))):
|
267 |
+
input_ids, label_ids = batch
|
268 |
+
input_ids = input_ids.to(device)
|
269 |
+
label_ids: torch.LongTensor = label_ids.to(device).long()
|
270 |
+
|
271 |
+
with torch.no_grad():
|
272 |
+
logits = model.forward(input_ids)
|
273 |
+
loss = focal_loss.forward(logits, label_ids.view(-1))
|
274 |
+
categorical_accuracy(logits, label_ids)
|
275 |
+
|
276 |
+
total_loss += loss.item()
|
277 |
+
total_examples += input_ids.size(0)
|
278 |
+
|
279 |
+
evaluation_loss = total_loss / total_examples
|
280 |
+
evaluation_loss = round(evaluation_loss, 4)
|
281 |
+
evaluation_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
|
282 |
+
evaluation_accuracy = round(evaluation_accuracy, 4)
|
283 |
+
logger.info("Epoch: {}; evaluation_loss: {}; evaluation_accuracy: {}".format(
|
284 |
+
idx_epoch, evaluation_loss, evaluation_accuracy
|
285 |
+
))
|
286 |
+
|
287 |
+
# save metric
|
288 |
+
metrics = {
|
289 |
+
"training_loss": training_loss,
|
290 |
+
"training_accuracy": training_accuracy,
|
291 |
+
"evaluation_loss": evaluation_loss,
|
292 |
+
"evaluation_accuracy": evaluation_accuracy,
|
293 |
+
"best_idx_epoch": best_idx_epoch,
|
294 |
+
"best_accuracy": best_accuracy,
|
295 |
+
}
|
296 |
+
metrics_filename = os.path.join(args.serialization_dir, "metrics_epoch_{}.json".format(idx_epoch))
|
297 |
+
with open(metrics_filename, "w", encoding="utf-8") as f:
|
298 |
+
json.dump(metrics, f, indent=4, ensure_ascii=False)
|
299 |
+
|
300 |
+
# save model
|
301 |
+
model_filename = os.path.join(args.serialization_dir, "model_epoch_{}.bin".format(idx_epoch))
|
302 |
+
model_filename_list.append(model_filename)
|
303 |
+
if len(model_filename_list) >= args.num_serialized_models_to_keep:
|
304 |
+
model_filename_to_delete = model_filename_list.pop(0)
|
305 |
+
os.remove(model_filename_to_delete)
|
306 |
+
torch.save(model.state_dict(), model_filename)
|
307 |
+
|
308 |
+
# early stop
|
309 |
+
best_model_filename = os.path.join(args.serialization_dir, "best.bin")
|
310 |
+
if best_accuracy is None:
|
311 |
+
best_idx_epoch = idx_epoch
|
312 |
+
best_accuracy = evaluation_accuracy
|
313 |
+
torch.save(model.state_dict(), best_model_filename)
|
314 |
+
elif evaluation_accuracy > best_accuracy:
|
315 |
+
best_idx_epoch = idx_epoch
|
316 |
+
best_accuracy = evaluation_accuracy
|
317 |
+
torch.save(model.state_dict(), best_model_filename)
|
318 |
+
patience_count = 0
|
319 |
+
elif patience_count >= args.patience:
|
320 |
+
break
|
321 |
+
else:
|
322 |
+
patience_count += 1
|
323 |
+
|
324 |
+
return
|
325 |
+
|
326 |
+
|
327 |
+
if __name__ == "__main__":
|
328 |
+
main()
|
examples/vm_sound_classification8/step_4_train_country_model.py
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
只训练 cls_head 部分的参数, 模型的准确率会更低.
|
5 |
+
"""
|
6 |
+
import argparse
|
7 |
+
from collections import defaultdict
|
8 |
+
import json
|
9 |
+
import logging
|
10 |
+
from logging.handlers import TimedRotatingFileHandler
|
11 |
+
import os
|
12 |
+
import platform
|
13 |
+
from pathlib import Path
|
14 |
+
import sys
|
15 |
+
import shutil
|
16 |
+
from typing import List
|
17 |
+
|
18 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
19 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
20 |
+
|
21 |
+
import pandas as pd
|
22 |
+
import torch
|
23 |
+
from torch.utils.data.dataloader import DataLoader
|
24 |
+
from tqdm import tqdm
|
25 |
+
|
26 |
+
from toolbox.torch.modules.loss import FocalLoss, HingeLoss, HingeLinear
|
27 |
+
from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
|
28 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
29 |
+
from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
|
30 |
+
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveEncoder, ClsHead, WaveClassifier
|
31 |
+
|
32 |
+
|
33 |
+
def get_args():
|
34 |
+
parser = argparse.ArgumentParser()
|
35 |
+
parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
|
36 |
+
parser.add_argument("--train_dataset", default="train.xlsx", type=str)
|
37 |
+
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
|
38 |
+
|
39 |
+
parser.add_argument("--country", default="en-US", type=str)
|
40 |
+
parser.add_argument("--shared_encoder", default="file_dir/global_model/best.bin", type=str)
|
41 |
+
|
42 |
+
parser.add_argument("--max_epochs", default=100, type=int)
|
43 |
+
parser.add_argument("--batch_size", default=64, type=int)
|
44 |
+
parser.add_argument("--learning_rate", default=1e-3, type=float)
|
45 |
+
parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
|
46 |
+
parser.add_argument("--patience", default=5, type=int)
|
47 |
+
parser.add_argument("--serialization_dir", default="country_models", type=str)
|
48 |
+
parser.add_argument("--seed", default=0, type=int)
|
49 |
+
|
50 |
+
args = parser.parse_args()
|
51 |
+
return args
|
52 |
+
|
53 |
+
|
54 |
+
def logging_config(file_dir: str):
|
55 |
+
fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
|
56 |
+
|
57 |
+
logging.basicConfig(format=fmt,
|
58 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
59 |
+
level=logging.DEBUG)
|
60 |
+
file_handler = TimedRotatingFileHandler(
|
61 |
+
filename=os.path.join(file_dir, "main.log"),
|
62 |
+
encoding="utf-8",
|
63 |
+
when="D",
|
64 |
+
interval=1,
|
65 |
+
backupCount=7
|
66 |
+
)
|
67 |
+
file_handler.setLevel(logging.INFO)
|
68 |
+
file_handler.setFormatter(logging.Formatter(fmt))
|
69 |
+
logger = logging.getLogger(__name__)
|
70 |
+
logger.addHandler(file_handler)
|
71 |
+
|
72 |
+
return logger
|
73 |
+
|
74 |
+
|
75 |
+
class CollateFunction(object):
|
76 |
+
def __init__(self):
|
77 |
+
pass
|
78 |
+
|
79 |
+
def __call__(self, batch: List[dict]):
|
80 |
+
array_list = list()
|
81 |
+
label_list = list()
|
82 |
+
for sample in batch:
|
83 |
+
array = sample['waveform']
|
84 |
+
label = sample['label']
|
85 |
+
|
86 |
+
array_list.append(array)
|
87 |
+
label_list.append(label)
|
88 |
+
|
89 |
+
array_list = torch.stack(array_list)
|
90 |
+
label_list = torch.stack(label_list)
|
91 |
+
return array_list, label_list
|
92 |
+
|
93 |
+
|
94 |
+
collate_fn = CollateFunction()
|
95 |
+
|
96 |
+
|
97 |
+
def main():
|
98 |
+
args = get_args()
|
99 |
+
|
100 |
+
serialization_dir = Path(args.serialization_dir)
|
101 |
+
serialization_dir.mkdir(parents=True, exist_ok=True)
|
102 |
+
|
103 |
+
logger = logging_config(args.serialization_dir)
|
104 |
+
|
105 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
106 |
+
n_gpu = torch.cuda.device_count()
|
107 |
+
logger.info("GPU available: {}; device: {}".format(n_gpu, device))
|
108 |
+
|
109 |
+
vocabulary = Vocabulary.from_files(args.vocabulary_dir)
|
110 |
+
|
111 |
+
# datasets
|
112 |
+
logger.info("prepare datasets")
|
113 |
+
train_dataset = WaveClassifierExcelDataset(
|
114 |
+
vocab=vocabulary,
|
115 |
+
excel_file=args.train_dataset,
|
116 |
+
category=args.country,
|
117 |
+
category_field="category",
|
118 |
+
label_field="country_labels",
|
119 |
+
expected_sample_rate=8000,
|
120 |
+
max_wave_value=32768.0,
|
121 |
+
)
|
122 |
+
valid_dataset = WaveClassifierExcelDataset(
|
123 |
+
vocab=vocabulary,
|
124 |
+
excel_file=args.valid_dataset,
|
125 |
+
category=args.country,
|
126 |
+
category_field="category",
|
127 |
+
label_field="country_labels",
|
128 |
+
expected_sample_rate=8000,
|
129 |
+
max_wave_value=32768.0,
|
130 |
+
)
|
131 |
+
|
132 |
+
train_data_loader = DataLoader(
|
133 |
+
dataset=train_dataset,
|
134 |
+
batch_size=args.batch_size,
|
135 |
+
shuffle=True,
|
136 |
+
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
137 |
+
num_workers=0 if platform.system() == "Windows" else os.cpu_count(),
|
138 |
+
collate_fn=collate_fn,
|
139 |
+
pin_memory=False,
|
140 |
+
# prefetch_factor=64,
|
141 |
+
)
|
142 |
+
valid_data_loader = DataLoader(
|
143 |
+
dataset=valid_dataset,
|
144 |
+
batch_size=args.batch_size,
|
145 |
+
shuffle=True,
|
146 |
+
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
147 |
+
num_workers=0 if platform.system() == "Windows" else os.cpu_count(),
|
148 |
+
collate_fn=collate_fn,
|
149 |
+
pin_memory=False,
|
150 |
+
# prefetch_factor=64,
|
151 |
+
)
|
152 |
+
|
153 |
+
# models - classifier
|
154 |
+
wave_encoder = WaveEncoder(
|
155 |
+
conv1d_block_param_list=[
|
156 |
+
{
|
157 |
+
'batch_norm': True,
|
158 |
+
'in_channels': 80,
|
159 |
+
'out_channels': 16,
|
160 |
+
'kernel_size': 3,
|
161 |
+
'stride': 3,
|
162 |
+
# 'padding': 'same',
|
163 |
+
'activation': 'relu',
|
164 |
+
'dropout': 0.1,
|
165 |
+
},
|
166 |
+
{
|
167 |
+
# 'batch_norm': True,
|
168 |
+
'in_channels': 16,
|
169 |
+
'out_channels': 16,
|
170 |
+
'kernel_size': 3,
|
171 |
+
'stride': 3,
|
172 |
+
# 'padding': 'same',
|
173 |
+
'activation': 'relu',
|
174 |
+
'dropout': 0.1,
|
175 |
+
},
|
176 |
+
{
|
177 |
+
# 'batch_norm': True,
|
178 |
+
'in_channels': 16,
|
179 |
+
'out_channels': 16,
|
180 |
+
'kernel_size': 3,
|
181 |
+
'stride': 3,
|
182 |
+
# 'padding': 'same',
|
183 |
+
'activation': 'relu',
|
184 |
+
'dropout': 0.1,
|
185 |
+
},
|
186 |
+
],
|
187 |
+
mel_spectrogram_param={
|
188 |
+
"sample_rate": 8000,
|
189 |
+
"n_fft": 512,
|
190 |
+
"win_length": 200,
|
191 |
+
"hop_length": 80,
|
192 |
+
"f_min": 10,
|
193 |
+
"f_max": 3800,
|
194 |
+
"window_fn": "hamming",
|
195 |
+
"n_mels": 80,
|
196 |
+
}
|
197 |
+
)
|
198 |
+
|
199 |
+
with open(args.shared_encoder, "rb") as f:
|
200 |
+
state_dict = torch.load(f, map_location=device)
|
201 |
+
processed_state_dict = dict()
|
202 |
+
prefix = "wave_encoder."
|
203 |
+
for k, v in state_dict.items():
|
204 |
+
if not str(k).startswith(prefix):
|
205 |
+
continue
|
206 |
+
k = k[len(prefix):]
|
207 |
+
processed_state_dict[k] = v
|
208 |
+
|
209 |
+
wave_encoder.load_state_dict(
|
210 |
+
state_dict=processed_state_dict,
|
211 |
+
strict=True,
|
212 |
+
)
|
213 |
+
cls_head = ClsHead(
|
214 |
+
input_dim=16,
|
215 |
+
num_layers=2,
|
216 |
+
hidden_dims=[32, 16],
|
217 |
+
activations="relu",
|
218 |
+
dropout=0.1,
|
219 |
+
num_labels=vocabulary.get_vocab_size(namespace="global_labels")
|
220 |
+
)
|
221 |
+
model = WaveClassifier(
|
222 |
+
wave_encoder=wave_encoder,
|
223 |
+
cls_head=cls_head,
|
224 |
+
)
|
225 |
+
model.wave_encoder.requires_grad_(requires_grad=False)
|
226 |
+
model.cls_head.requires_grad_(requires_grad=True)
|
227 |
+
model.to(device)
|
228 |
+
|
229 |
+
# optimizer
|
230 |
+
logger.info("prepare optimizer")
|
231 |
+
optimizer = torch.optim.Adam(
|
232 |
+
model.cls_head.parameters(),
|
233 |
+
lr=args.learning_rate,
|
234 |
+
)
|
235 |
+
lr_scheduler = torch.optim.lr_scheduler.StepLR(
|
236 |
+
optimizer,
|
237 |
+
step_size=2000
|
238 |
+
)
|
239 |
+
focal_loss = FocalLoss(
|
240 |
+
num_classes=vocabulary.get_vocab_size(namespace=args.country),
|
241 |
+
reduction="mean",
|
242 |
+
)
|
243 |
+
categorical_accuracy = CategoricalAccuracy()
|
244 |
+
|
245 |
+
# training loop
|
246 |
+
best_idx_epoch: int = None
|
247 |
+
best_accuracy: float = None
|
248 |
+
patience_count = 0
|
249 |
+
global_step = 0
|
250 |
+
model_filename_list = list()
|
251 |
+
for idx_epoch in range(args.max_epochs):
|
252 |
+
|
253 |
+
# training
|
254 |
+
model.train()
|
255 |
+
total_loss = 0
|
256 |
+
total_examples = 0
|
257 |
+
for step, batch in enumerate(tqdm(train_data_loader, desc="Epoch={} (training)".format(idx_epoch))):
|
258 |
+
input_ids, label_ids = batch
|
259 |
+
input_ids = input_ids.to(device)
|
260 |
+
label_ids: torch.LongTensor = label_ids.to(device).long()
|
261 |
+
|
262 |
+
logits = model.forward(input_ids)
|
263 |
+
loss = focal_loss.forward(logits, label_ids.view(-1))
|
264 |
+
categorical_accuracy(logits, label_ids)
|
265 |
+
|
266 |
+
total_loss += loss.item()
|
267 |
+
total_examples += input_ids.size(0)
|
268 |
+
|
269 |
+
optimizer.zero_grad()
|
270 |
+
loss.backward()
|
271 |
+
optimizer.step()
|
272 |
+
lr_scheduler.step()
|
273 |
+
|
274 |
+
global_step += 1
|
275 |
+
training_loss = total_loss / total_examples
|
276 |
+
training_loss = round(training_loss, 4)
|
277 |
+
training_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
|
278 |
+
training_accuracy = round(training_accuracy, 4)
|
279 |
+
logger.info("Epoch: {}; training_loss: {}; training_accuracy: {}".format(
|
280 |
+
idx_epoch, training_loss, training_accuracy
|
281 |
+
))
|
282 |
+
|
283 |
+
# evaluation
|
284 |
+
model.eval()
|
285 |
+
total_loss = 0
|
286 |
+
total_examples = 0
|
287 |
+
for step, batch in enumerate(tqdm(valid_data_loader, desc="Epoch={} (evaluation)".format(idx_epoch))):
|
288 |
+
input_ids, label_ids = batch
|
289 |
+
input_ids = input_ids.to(device)
|
290 |
+
label_ids: torch.LongTensor = label_ids.to(device).long()
|
291 |
+
|
292 |
+
with torch.no_grad():
|
293 |
+
logits = model.forward(input_ids)
|
294 |
+
loss = focal_loss.forward(logits, label_ids.view(-1))
|
295 |
+
categorical_accuracy(logits, label_ids)
|
296 |
+
|
297 |
+
total_loss += loss.item()
|
298 |
+
total_examples += input_ids.size(0)
|
299 |
+
|
300 |
+
evaluation_loss = total_loss / total_examples
|
301 |
+
evaluation_loss = round(evaluation_loss, 4)
|
302 |
+
evaluation_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
|
303 |
+
evaluation_accuracy = round(evaluation_accuracy, 4)
|
304 |
+
logger.info("Epoch: {}; evaluation_loss: {}; evaluation_accuracy: {}".format(
|
305 |
+
idx_epoch, evaluation_loss, evaluation_accuracy
|
306 |
+
))
|
307 |
+
|
308 |
+
# save metric
|
309 |
+
metrics = {
|
310 |
+
"training_loss": training_loss,
|
311 |
+
"training_accuracy": training_accuracy,
|
312 |
+
"evaluation_loss": evaluation_loss,
|
313 |
+
"evaluation_accuracy": evaluation_accuracy,
|
314 |
+
"best_idx_epoch": best_idx_epoch,
|
315 |
+
"best_accuracy": best_accuracy,
|
316 |
+
}
|
317 |
+
metrics_filename = os.path.join(args.serialization_dir, "metrics_epoch_{}.json".format(idx_epoch))
|
318 |
+
with open(metrics_filename, "w", encoding="utf-8") as f:
|
319 |
+
json.dump(metrics, f, indent=4, ensure_ascii=False)
|
320 |
+
|
321 |
+
# save model
|
322 |
+
model_filename = os.path.join(args.serialization_dir, "model_epoch_{}.bin".format(idx_epoch))
|
323 |
+
model_filename_list.append(model_filename)
|
324 |
+
if len(model_filename_list) >= args.num_serialized_models_to_keep:
|
325 |
+
model_filename_to_delete = model_filename_list.pop(0)
|
326 |
+
os.remove(model_filename_to_delete)
|
327 |
+
torch.save(model.state_dict(), model_filename)
|
328 |
+
|
329 |
+
# early stop
|
330 |
+
best_model_filename = os.path.join(args.serialization_dir, "best.bin")
|
331 |
+
if best_accuracy is None:
|
332 |
+
best_idx_epoch = idx_epoch
|
333 |
+
best_accuracy = evaluation_accuracy
|
334 |
+
torch.save(model.state_dict(), best_model_filename)
|
335 |
+
elif evaluation_accuracy > best_accuracy:
|
336 |
+
best_idx_epoch = idx_epoch
|
337 |
+
best_accuracy = evaluation_accuracy
|
338 |
+
torch.save(model.state_dict(), best_model_filename)
|
339 |
+
patience_count = 0
|
340 |
+
elif patience_count >= args.patience:
|
341 |
+
break
|
342 |
+
else:
|
343 |
+
patience_count += 1
|
344 |
+
|
345 |
+
return
|
346 |
+
|
347 |
+
|
348 |
+
if __name__ == "__main__":
|
349 |
+
main()
|
examples/vm_sound_classification8/step_5_train_union.py
ADDED
@@ -0,0 +1,499 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
from collections import defaultdict
|
5 |
+
import json
|
6 |
+
import logging
|
7 |
+
from logging.handlers import TimedRotatingFileHandler
|
8 |
+
import os
|
9 |
+
import platform
|
10 |
+
from pathlib import Path
|
11 |
+
import sys
|
12 |
+
import shutil
|
13 |
+
from typing import List
|
14 |
+
|
15 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
16 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
17 |
+
|
18 |
+
import pandas as pd
|
19 |
+
import torch
|
20 |
+
from torch.utils.data.dataloader import DataLoader
|
21 |
+
from tqdm import tqdm
|
22 |
+
|
23 |
+
from toolbox.torch.modules.loss import FocalLoss, HingeLoss, HingeLinear
|
24 |
+
from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
|
25 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
26 |
+
from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
|
27 |
+
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveEncoder, ClsHead, WaveClassifier
|
28 |
+
|
29 |
+
|
30 |
+
def get_args():
|
31 |
+
parser = argparse.ArgumentParser()
|
32 |
+
parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
|
33 |
+
|
34 |
+
parser.add_argument("--train_dataset", default="train.xlsx", type=str)
|
35 |
+
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
|
36 |
+
|
37 |
+
parser.add_argument("--max_steps", default=100000, type=int)
|
38 |
+
parser.add_argument("--save_steps", default=30, type=int)
|
39 |
+
|
40 |
+
parser.add_argument("--batch_size", default=1, type=int)
|
41 |
+
parser.add_argument("--learning_rate", default=1e-3, type=float)
|
42 |
+
parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
|
43 |
+
parser.add_argument("--patience", default=5, type=int)
|
44 |
+
parser.add_argument("--serialization_dir", default="union", type=str)
|
45 |
+
parser.add_argument("--seed", default=0, type=int)
|
46 |
+
|
47 |
+
parser.add_argument("--num_workers", default=0, type=int)
|
48 |
+
|
49 |
+
args = parser.parse_args()
|
50 |
+
return args
|
51 |
+
|
52 |
+
|
53 |
+
def logging_config(file_dir: str):
|
54 |
+
fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
|
55 |
+
|
56 |
+
logging.basicConfig(format=fmt,
|
57 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
58 |
+
level=logging.DEBUG)
|
59 |
+
file_handler = TimedRotatingFileHandler(
|
60 |
+
filename=os.path.join(file_dir, "main.log"),
|
61 |
+
encoding="utf-8",
|
62 |
+
when="D",
|
63 |
+
interval=1,
|
64 |
+
backupCount=7
|
65 |
+
)
|
66 |
+
file_handler.setLevel(logging.INFO)
|
67 |
+
file_handler.setFormatter(logging.Formatter(fmt))
|
68 |
+
logger = logging.getLogger(__name__)
|
69 |
+
logger.addHandler(file_handler)
|
70 |
+
|
71 |
+
return logger
|
72 |
+
|
73 |
+
|
74 |
+
class CollateFunction(object):
|
75 |
+
def __init__(self):
|
76 |
+
pass
|
77 |
+
|
78 |
+
def __call__(self, batch: List[dict]):
|
79 |
+
array_list = list()
|
80 |
+
label_list = list()
|
81 |
+
for sample in batch:
|
82 |
+
array = sample['waveform']
|
83 |
+
label = sample['label']
|
84 |
+
|
85 |
+
array_list.append(array)
|
86 |
+
label_list.append(label)
|
87 |
+
|
88 |
+
array_list = torch.stack(array_list)
|
89 |
+
label_list = torch.stack(label_list)
|
90 |
+
return array_list, label_list
|
91 |
+
|
92 |
+
|
93 |
+
collate_fn = CollateFunction()
|
94 |
+
|
95 |
+
|
96 |
+
class DatasetIterator(object):
|
97 |
+
def __init__(self, data_loader: DataLoader):
|
98 |
+
self.data_loader = data_loader
|
99 |
+
self.data_loader_iter = iter(self.data_loader)
|
100 |
+
|
101 |
+
def next(self):
|
102 |
+
try:
|
103 |
+
result = self.data_loader_iter.__next__()
|
104 |
+
except StopIteration:
|
105 |
+
self.data_loader_iter = iter(self.data_loader)
|
106 |
+
result = self.data_loader_iter.__next__()
|
107 |
+
return result
|
108 |
+
|
109 |
+
|
110 |
+
def main():
|
111 |
+
args = get_args()
|
112 |
+
|
113 |
+
serialization_dir = Path(args.serialization_dir)
|
114 |
+
serialization_dir.mkdir(parents=True, exist_ok=True)
|
115 |
+
|
116 |
+
logger = logging_config(args.serialization_dir)
|
117 |
+
|
118 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
119 |
+
n_gpu = torch.cuda.device_count()
|
120 |
+
logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
|
121 |
+
|
122 |
+
vocabulary = Vocabulary.from_files(args.vocabulary_dir)
|
123 |
+
namespaces = vocabulary._token_to_index.keys()
|
124 |
+
|
125 |
+
# namespace_to_ratio
|
126 |
+
max_radio = (len(namespaces) - 1) * 3
|
127 |
+
namespace_to_ratio = {n: 1 for n in namespaces}
|
128 |
+
namespace_to_ratio["global_labels"] = max_radio
|
129 |
+
|
130 |
+
# datasets
|
131 |
+
logger.info("prepare datasets")
|
132 |
+
namespace_to_datasets = dict()
|
133 |
+
for namespace in namespaces:
|
134 |
+
logger.info("prepare datasets - {}".format(namespace))
|
135 |
+
if namespace == "global_labels":
|
136 |
+
train_dataset = WaveClassifierExcelDataset(
|
137 |
+
vocab=vocabulary,
|
138 |
+
excel_file=args.train_dataset,
|
139 |
+
category=None,
|
140 |
+
category_field="category",
|
141 |
+
label_field="global_labels",
|
142 |
+
expected_sample_rate=8000,
|
143 |
+
max_wave_value=32768.0,
|
144 |
+
)
|
145 |
+
valid_dataset = WaveClassifierExcelDataset(
|
146 |
+
vocab=vocabulary,
|
147 |
+
excel_file=args.valid_dataset,
|
148 |
+
category=None,
|
149 |
+
category_field="category",
|
150 |
+
label_field="global_labels",
|
151 |
+
expected_sample_rate=8000,
|
152 |
+
max_wave_value=32768.0,
|
153 |
+
)
|
154 |
+
else:
|
155 |
+
train_dataset = WaveClassifierExcelDataset(
|
156 |
+
vocab=vocabulary,
|
157 |
+
excel_file=args.train_dataset,
|
158 |
+
category=namespace,
|
159 |
+
category_field="category",
|
160 |
+
label_field="country_labels",
|
161 |
+
expected_sample_rate=8000,
|
162 |
+
max_wave_value=32768.0,
|
163 |
+
)
|
164 |
+
valid_dataset = WaveClassifierExcelDataset(
|
165 |
+
vocab=vocabulary,
|
166 |
+
excel_file=args.valid_dataset,
|
167 |
+
category=namespace,
|
168 |
+
category_field="category",
|
169 |
+
label_field="country_labels",
|
170 |
+
expected_sample_rate=8000,
|
171 |
+
max_wave_value=32768.0,
|
172 |
+
)
|
173 |
+
|
174 |
+
train_data_loader = DataLoader(
|
175 |
+
dataset=train_dataset,
|
176 |
+
batch_size=args.batch_size,
|
177 |
+
shuffle=True,
|
178 |
+
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
179 |
+
# num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
|
180 |
+
num_workers=args.num_workers,
|
181 |
+
collate_fn=collate_fn,
|
182 |
+
pin_memory=False,
|
183 |
+
# prefetch_factor=64,
|
184 |
+
)
|
185 |
+
valid_data_loader = DataLoader(
|
186 |
+
dataset=valid_dataset,
|
187 |
+
batch_size=args.batch_size,
|
188 |
+
shuffle=True,
|
189 |
+
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
190 |
+
# num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
|
191 |
+
num_workers=args.num_workers,
|
192 |
+
collate_fn=collate_fn,
|
193 |
+
pin_memory=False,
|
194 |
+
# prefetch_factor=64,
|
195 |
+
)
|
196 |
+
|
197 |
+
namespace_to_datasets[namespace] = {
|
198 |
+
"train_data_loader": train_data_loader,
|
199 |
+
"valid_data_loader": valid_data_loader,
|
200 |
+
}
|
201 |
+
|
202 |
+
# datasets iterator
|
203 |
+
logger.info("prepare datasets iterator")
|
204 |
+
namespace_to_datasets_iter = dict()
|
205 |
+
for namespace in namespaces:
|
206 |
+
logger.info("prepare datasets iterator - {}".format(namespace))
|
207 |
+
train_data_loader = namespace_to_datasets[namespace]["train_data_loader"]
|
208 |
+
valid_data_loader = namespace_to_datasets[namespace]["valid_data_loader"]
|
209 |
+
namespace_to_datasets_iter[namespace] = {
|
210 |
+
"train_data_loader_iter": DatasetIterator(train_data_loader),
|
211 |
+
"valid_data_loader_iter": DatasetIterator(valid_data_loader),
|
212 |
+
}
|
213 |
+
|
214 |
+
# models - encoder
|
215 |
+
logger.info("prepare models - encoder")
|
216 |
+
wave_encoder = WaveEncoder(
|
217 |
+
conv2d_block_param_list=[
|
218 |
+
{
|
219 |
+
"batch_norm": True,
|
220 |
+
"in_channels": 1,
|
221 |
+
"out_channels": 4,
|
222 |
+
"kernel_size": 3,
|
223 |
+
"stride": 1,
|
224 |
+
# "padding": "same",
|
225 |
+
"dilation": 3,
|
226 |
+
"activation": "relu",
|
227 |
+
"dropout": 0.1,
|
228 |
+
},
|
229 |
+
{
|
230 |
+
# "batch_norm": True,
|
231 |
+
"in_channels": 4,
|
232 |
+
"out_channels": 4,
|
233 |
+
"kernel_size": 5,
|
234 |
+
"stride": 2,
|
235 |
+
# "padding": "same",
|
236 |
+
"dilation": 3,
|
237 |
+
"activation": "relu",
|
238 |
+
"dropout": 0.1,
|
239 |
+
},
|
240 |
+
{
|
241 |
+
# "batch_norm": True,
|
242 |
+
"in_channels": 4,
|
243 |
+
"out_channels": 4,
|
244 |
+
"kernel_size": 3,
|
245 |
+
"stride": 1,
|
246 |
+
# "padding": "same",
|
247 |
+
"dilation": 2,
|
248 |
+
"activation": "relu",
|
249 |
+
"dropout": 0.1,
|
250 |
+
},
|
251 |
+
],
|
252 |
+
mel_spectrogram_param={
|
253 |
+
'sample_rate': 8000,
|
254 |
+
'n_fft': 512,
|
255 |
+
'win_length': 200,
|
256 |
+
'hop_length': 80,
|
257 |
+
'f_min': 10,
|
258 |
+
'f_max': 3800,
|
259 |
+
'window_fn': 'hamming',
|
260 |
+
'n_mels': 80,
|
261 |
+
}
|
262 |
+
)
|
263 |
+
|
264 |
+
# models - cls_head
|
265 |
+
logger.info("prepare models - cls_head")
|
266 |
+
namespace_to_cls_heads = dict()
|
267 |
+
for namespace in namespaces:
|
268 |
+
logger.info("prepare models - cls_head - {}".format(namespace))
|
269 |
+
cls_head = ClsHead(
|
270 |
+
input_dim=352,
|
271 |
+
num_layers=2,
|
272 |
+
hidden_dims=[128, 32],
|
273 |
+
activations="relu",
|
274 |
+
dropout=0.1,
|
275 |
+
num_labels=vocabulary.get_vocab_size(namespace=namespace)
|
276 |
+
)
|
277 |
+
namespace_to_cls_heads[namespace] = cls_head
|
278 |
+
|
279 |
+
# models - classifier
|
280 |
+
logger.info("prepare models - classifier")
|
281 |
+
namespace_to_classifier = dict()
|
282 |
+
for namespace in namespaces:
|
283 |
+
logger.info("prepare models - classifier - {}".format(namespace))
|
284 |
+
cls_head = namespace_to_cls_heads[namespace]
|
285 |
+
wave_classifier = WaveClassifier(
|
286 |
+
wave_encoder=wave_encoder,
|
287 |
+
cls_head=cls_head,
|
288 |
+
)
|
289 |
+
wave_classifier.to(device)
|
290 |
+
namespace_to_classifier[namespace] = wave_classifier
|
291 |
+
|
292 |
+
# optimizer
|
293 |
+
logger.info("prepare optimizer")
|
294 |
+
param_optimizer = list()
|
295 |
+
param_optimizer.extend(wave_encoder.parameters())
|
296 |
+
for _, cls_head in namespace_to_cls_heads.items():
|
297 |
+
param_optimizer.extend(cls_head.parameters())
|
298 |
+
|
299 |
+
optimizer = torch.optim.Adam(
|
300 |
+
param_optimizer,
|
301 |
+
lr=args.learning_rate,
|
302 |
+
)
|
303 |
+
lr_scheduler = torch.optim.lr_scheduler.StepLR(
|
304 |
+
optimizer,
|
305 |
+
step_size=10000
|
306 |
+
)
|
307 |
+
focal_loss = FocalLoss(
|
308 |
+
num_classes=vocabulary.get_vocab_size(namespace="global_labels"),
|
309 |
+
reduction="mean",
|
310 |
+
)
|
311 |
+
|
312 |
+
# categorical_accuracy
|
313 |
+
logger.info("prepare categorical_accuracy")
|
314 |
+
namespace_to_categorical_accuracy = dict()
|
315 |
+
for namespace in namespaces:
|
316 |
+
categorical_accuracy = CategoricalAccuracy()
|
317 |
+
namespace_to_categorical_accuracy[namespace] = categorical_accuracy
|
318 |
+
|
319 |
+
# training loop
|
320 |
+
logger.info("prepare training loop")
|
321 |
+
|
322 |
+
model_list = list()
|
323 |
+
best_idx_step = None
|
324 |
+
best_accuracy = None
|
325 |
+
patience_count = 0
|
326 |
+
|
327 |
+
namespace_to_total_loss = defaultdict(float)
|
328 |
+
namespace_to_total_examples = defaultdict(int)
|
329 |
+
for idx_step in tqdm(range(args.max_steps)):
|
330 |
+
|
331 |
+
# training one step
|
332 |
+
loss: torch.Tensor = None
|
333 |
+
for namespace in namespaces:
|
334 |
+
train_data_loader_iter = namespace_to_datasets_iter[namespace]["train_data_loader_iter"]
|
335 |
+
|
336 |
+
ratio = namespace_to_ratio[namespace]
|
337 |
+
model = namespace_to_classifier[namespace]
|
338 |
+
categorical_accuracy = namespace_to_categorical_accuracy[namespace]
|
339 |
+
|
340 |
+
model.train()
|
341 |
+
|
342 |
+
for _ in range(ratio):
|
343 |
+
batch = train_data_loader_iter.next()
|
344 |
+
input_ids, label_ids = batch
|
345 |
+
input_ids = input_ids.to(device)
|
346 |
+
label_ids: torch.LongTensor = label_ids.to(device).long()
|
347 |
+
|
348 |
+
logits = model.forward(input_ids)
|
349 |
+
task_loss = focal_loss.forward(logits, label_ids.view(-1))
|
350 |
+
categorical_accuracy(logits, label_ids)
|
351 |
+
|
352 |
+
if loss is None:
|
353 |
+
loss = task_loss
|
354 |
+
else:
|
355 |
+
loss += task_loss
|
356 |
+
|
357 |
+
namespace_to_total_loss[namespace] += task_loss.item()
|
358 |
+
namespace_to_total_examples[namespace] += input_ids.size(0)
|
359 |
+
|
360 |
+
optimizer.zero_grad()
|
361 |
+
loss.backward()
|
362 |
+
optimizer.step()
|
363 |
+
lr_scheduler.step()
|
364 |
+
|
365 |
+
# logging
|
366 |
+
if (idx_step + 1) % args.save_steps == 0:
|
367 |
+
metrics = dict()
|
368 |
+
|
369 |
+
# training
|
370 |
+
for namespace in namespaces:
|
371 |
+
total_loss = namespace_to_total_loss[namespace]
|
372 |
+
total_examples = namespace_to_total_examples[namespace]
|
373 |
+
|
374 |
+
training_loss = total_loss / total_examples
|
375 |
+
training_loss = round(training_loss, 4)
|
376 |
+
|
377 |
+
categorical_accuracy = namespace_to_categorical_accuracy[namespace]
|
378 |
+
|
379 |
+
training_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
|
380 |
+
training_accuracy = round(training_accuracy, 4)
|
381 |
+
logger.info("Step: {}; namespace: {}; training_loss: {}; training_accuracy: {}".format(
|
382 |
+
idx_step, namespace, training_loss, training_accuracy
|
383 |
+
))
|
384 |
+
metrics[namespace] = {
|
385 |
+
"training_loss": training_loss,
|
386 |
+
"training_accuracy": training_accuracy,
|
387 |
+
}
|
388 |
+
namespace_to_total_loss = defaultdict(float)
|
389 |
+
namespace_to_total_examples = defaultdict(int)
|
390 |
+
|
391 |
+
# evaluation
|
392 |
+
for namespace in namespaces:
|
393 |
+
valid_data_loader = namespace_to_datasets[namespace]["valid_data_loader"]
|
394 |
+
|
395 |
+
model = namespace_to_classifier[namespace]
|
396 |
+
categorical_accuracy = namespace_to_categorical_accuracy[namespace]
|
397 |
+
|
398 |
+
model.eval()
|
399 |
+
|
400 |
+
total_loss = 0
|
401 |
+
total_examples = 0
|
402 |
+
for step, batch in enumerate(valid_data_loader):
|
403 |
+
input_ids, label_ids = batch
|
404 |
+
input_ids = input_ids.to(device)
|
405 |
+
label_ids: torch.LongTensor = label_ids.to(device).long()
|
406 |
+
|
407 |
+
with torch.no_grad():
|
408 |
+
logits = model.forward(input_ids)
|
409 |
+
loss = focal_loss.forward(logits, label_ids.view(-1))
|
410 |
+
categorical_accuracy(logits, label_ids)
|
411 |
+
|
412 |
+
total_loss += loss.item()
|
413 |
+
total_examples += input_ids.size(0)
|
414 |
+
|
415 |
+
evaluation_loss = total_loss / total_examples
|
416 |
+
evaluation_loss = round(evaluation_loss, 4)
|
417 |
+
evaluation_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
|
418 |
+
evaluation_accuracy = round(evaluation_accuracy, 4)
|
419 |
+
logger.info("Step: {}; namespace: {}; evaluation_loss: {}; evaluation_accuracy: {}".format(
|
420 |
+
idx_step, namespace, evaluation_loss, evaluation_accuracy
|
421 |
+
))
|
422 |
+
metrics[namespace] = {
|
423 |
+
"evaluation_loss": evaluation_loss,
|
424 |
+
"evaluation_accuracy": evaluation_accuracy,
|
425 |
+
}
|
426 |
+
|
427 |
+
# update ratio
|
428 |
+
min_accuracy = min([m["evaluation_accuracy"] for m in metrics.values()])
|
429 |
+
max_accuracy = max([m["evaluation_accuracy"] for m in metrics.values()])
|
430 |
+
width = max_accuracy - min_accuracy
|
431 |
+
for namespace, metric in metrics.items():
|
432 |
+
evaluation_accuracy = metric["evaluation_accuracy"]
|
433 |
+
radio = (max_accuracy - evaluation_accuracy) / width * max_radio
|
434 |
+
radio = int(radio)
|
435 |
+
namespace_to_ratio[namespace] = radio
|
436 |
+
|
437 |
+
msg = "".join(["{}: {}; ".format(k, v) for k, v in namespace_to_ratio.items()])
|
438 |
+
logger.info("namespace to ratio: {}".format(msg))
|
439 |
+
|
440 |
+
# save path
|
441 |
+
step_dir = serialization_dir / "step-{}".format(idx_step)
|
442 |
+
step_dir.mkdir(parents=True, exist_ok=False)
|
443 |
+
|
444 |
+
# save models
|
445 |
+
wave_encoder_filename = step_dir / "wave_encoder.pt"
|
446 |
+
torch.save(wave_encoder.state_dict(), wave_encoder_filename)
|
447 |
+
for namespace in namespaces:
|
448 |
+
cls_head_filename = step_dir / "{}.pt".format(namespace)
|
449 |
+
cls_head = namespace_to_cls_heads[namespace]
|
450 |
+
torch.save(cls_head.state_dict(), cls_head_filename)
|
451 |
+
|
452 |
+
model_list.append(step_dir)
|
453 |
+
if len(model_list) >= args.num_serialized_models_to_keep:
|
454 |
+
model_to_delete: Path = model_list.pop(0)
|
455 |
+
shutil.rmtree(model_to_delete.as_posix())
|
456 |
+
|
457 |
+
# save metric
|
458 |
+
this_accuracy = metrics["global_labels"]["evaluation_accuracy"]
|
459 |
+
if best_accuracy is None:
|
460 |
+
best_idx_step = idx_step
|
461 |
+
best_accuracy = this_accuracy
|
462 |
+
elif metrics["global_labels"]["evaluation_accuracy"] > best_accuracy:
|
463 |
+
best_idx_step = idx_step
|
464 |
+
best_accuracy = this_accuracy
|
465 |
+
else:
|
466 |
+
pass
|
467 |
+
|
468 |
+
metrics_filename = step_dir / "metrics_epoch.json"
|
469 |
+
metrics.update({
|
470 |
+
"idx_step": idx_step,
|
471 |
+
"best_idx_step": best_idx_step,
|
472 |
+
})
|
473 |
+
with open(metrics_filename, "w", encoding="utf-8") as f:
|
474 |
+
json.dump(metrics, f, indent=4, ensure_ascii=False)
|
475 |
+
|
476 |
+
# save best
|
477 |
+
best_dir = serialization_dir / "best"
|
478 |
+
if best_idx_step == idx_step:
|
479 |
+
if best_dir.exists():
|
480 |
+
shutil.rmtree(best_dir)
|
481 |
+
shutil.copytree(step_dir, best_dir)
|
482 |
+
|
483 |
+
# early stop
|
484 |
+
early_stop_flag = False
|
485 |
+
if best_idx_step == idx_step:
|
486 |
+
patience_count = 0
|
487 |
+
else:
|
488 |
+
patience_count += 1
|
489 |
+
if patience_count >= args.patience:
|
490 |
+
early_stop_flag = True
|
491 |
+
|
492 |
+
# early stop
|
493 |
+
if early_stop_flag:
|
494 |
+
break
|
495 |
+
return
|
496 |
+
|
497 |
+
|
498 |
+
if __name__ == "__main__":
|
499 |
+
main()
|
examples/vm_sound_classification8/stop.sh
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
kill -9 `ps -aef | grep 'vm_sound_classification/bin/python3' | grep -v grep | awk '{print $2}' | sed 's/\n/ /'`
|
install.sh
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# bash install.sh --stage 2 --stop_stage 2 --system_version centos
|
4 |
+
|
5 |
+
|
6 |
+
python_version=3.8.10
|
7 |
+
system_version="centos";
|
8 |
+
|
9 |
+
verbose=true;
|
10 |
+
stage=-1
|
11 |
+
stop_stage=0
|
12 |
+
|
13 |
+
|
14 |
+
# parse options
|
15 |
+
while true; do
|
16 |
+
[ -z "${1:-}" ] && break; # break if there are no arguments
|
17 |
+
case "$1" in
|
18 |
+
--*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
|
19 |
+
eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
|
20 |
+
old_value="(eval echo \\$$name)";
|
21 |
+
if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
|
22 |
+
was_bool=true;
|
23 |
+
else
|
24 |
+
was_bool=false;
|
25 |
+
fi
|
26 |
+
|
27 |
+
# Set the variable to the right value-- the escaped quotes make it work if
|
28 |
+
# the option had spaces, like --cmd "queue.pl -sync y"
|
29 |
+
eval "${name}=\"$2\"";
|
30 |
+
|
31 |
+
# Check that Boolean-valued arguments are really Boolean.
|
32 |
+
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
|
33 |
+
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
|
34 |
+
exit 1;
|
35 |
+
fi
|
36 |
+
shift 2;
|
37 |
+
;;
|
38 |
+
|
39 |
+
*) break;
|
40 |
+
esac
|
41 |
+
done
|
42 |
+
|
43 |
+
work_dir="$(pwd)"
|
44 |
+
|
45 |
+
|
46 |
+
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
47 |
+
$verbose && echo "stage 1: install python"
|
48 |
+
cd "${work_dir}" || exit 1;
|
49 |
+
|
50 |
+
sh ./script/install_python.sh --python_version "${python_version}" --system_version "${system_version}"
|
51 |
+
fi
|
52 |
+
|
53 |
+
|
54 |
+
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
55 |
+
$verbose && echo "stage 2: create virtualenv"
|
56 |
+
|
57 |
+
# /usr/local/python-3.6.5/bin/virtualenv vm_sound_classification
|
58 |
+
# source /data/local/bin/vm_sound_classification/bin/activate
|
59 |
+
/usr/local/python-${python_version}/bin/pip3 install virtualenv
|
60 |
+
mkdir -p /data/local/bin
|
61 |
+
cd /data/local/bin || exit 1;
|
62 |
+
/usr/local/python-${python_version}/bin/virtualenv vm_sound_classification
|
63 |
+
|
64 |
+
fi
|
main.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
from functools import lru_cache
|
5 |
+
import json
|
6 |
+
from pathlib import Path
|
7 |
+
import platform
|
8 |
+
import shutil
|
9 |
+
import tempfile
|
10 |
+
import zipfile
|
11 |
+
|
12 |
+
import gradio as gr
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
|
16 |
+
from project_settings import environment, project_path
|
17 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
18 |
+
|
19 |
+
|
20 |
+
def get_args():
|
21 |
+
parser = argparse.ArgumentParser()
|
22 |
+
parser.add_argument(
|
23 |
+
"--examples_dir",
|
24 |
+
default=(project_path / "data/examples").as_posix(),
|
25 |
+
type=str
|
26 |
+
)
|
27 |
+
parser.add_argument(
|
28 |
+
"--trained_model_dir",
|
29 |
+
default=(project_path / "trained_models").as_posix(),
|
30 |
+
type=str
|
31 |
+
)
|
32 |
+
parser.add_argument(
|
33 |
+
"--server_port",
|
34 |
+
default=environment.get("server_port", 7860),
|
35 |
+
type=int
|
36 |
+
)
|
37 |
+
args = parser.parse_args()
|
38 |
+
return args
|
39 |
+
|
40 |
+
|
41 |
+
@lru_cache(maxsize=100)
|
42 |
+
def load_model(model_file: Path):
|
43 |
+
with zipfile.ZipFile(model_file, "r") as f_zip:
|
44 |
+
out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
|
45 |
+
if out_root.exists():
|
46 |
+
shutil.rmtree(out_root.as_posix())
|
47 |
+
out_root.mkdir(parents=True, exist_ok=True)
|
48 |
+
f_zip.extractall(path=out_root)
|
49 |
+
|
50 |
+
tgt_path = out_root / model_file.stem
|
51 |
+
jit_model_file = tgt_path / "trace_model.zip"
|
52 |
+
vocab_path = tgt_path / "vocabulary"
|
53 |
+
|
54 |
+
vocabulary = Vocabulary.from_files(vocab_path.as_posix())
|
55 |
+
|
56 |
+
with open(jit_model_file.as_posix(), "rb") as f:
|
57 |
+
model = torch.jit.load(f)
|
58 |
+
model.eval()
|
59 |
+
|
60 |
+
shutil.rmtree(tgt_path)
|
61 |
+
|
62 |
+
d = {
|
63 |
+
"model": model,
|
64 |
+
"vocabulary": vocabulary
|
65 |
+
}
|
66 |
+
return d
|
67 |
+
|
68 |
+
|
69 |
+
def click_button(audio: np.ndarray,
|
70 |
+
model_name: str,
|
71 |
+
ground_true: str) -> str:
|
72 |
+
|
73 |
+
sample_rate, signal = audio
|
74 |
+
|
75 |
+
model_file = "trained_models/{}.zip".format(model_name)
|
76 |
+
model_file = Path(model_file)
|
77 |
+
d = load_model(model_file)
|
78 |
+
|
79 |
+
model = d["model"]
|
80 |
+
vocabulary = d["vocabulary"]
|
81 |
+
|
82 |
+
inputs = signal / (1 << 15)
|
83 |
+
inputs = torch.tensor(inputs, dtype=torch.float32)
|
84 |
+
inputs = torch.unsqueeze(inputs, dim=0)
|
85 |
+
|
86 |
+
with torch.no_grad():
|
87 |
+
logits = model.forward(inputs)
|
88 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
89 |
+
label_idx = torch.argmax(probs, dim=-1)
|
90 |
+
|
91 |
+
label_idx = label_idx.cpu()
|
92 |
+
probs = probs.cpu()
|
93 |
+
|
94 |
+
label_idx = label_idx.numpy()[0]
|
95 |
+
prob = probs.numpy()[0][label_idx]
|
96 |
+
|
97 |
+
label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
|
98 |
+
|
99 |
+
return label_str, round(prob, 4)
|
100 |
+
|
101 |
+
|
102 |
+
def main():
|
103 |
+
args = get_args()
|
104 |
+
|
105 |
+
examples_dir = Path(args.examples_dir)
|
106 |
+
trained_model_dir = Path(args.trained_model_dir)
|
107 |
+
|
108 |
+
# examples
|
109 |
+
examples = list()
|
110 |
+
for filename in examples_dir.glob("*/*/*.wav"):
|
111 |
+
language = filename.parts[-3]
|
112 |
+
label = filename.parts[-2]
|
113 |
+
|
114 |
+
examples.append([
|
115 |
+
filename.as_posix(),
|
116 |
+
language,
|
117 |
+
label
|
118 |
+
])
|
119 |
+
|
120 |
+
# models
|
121 |
+
model_choices = list()
|
122 |
+
for filename in trained_model_dir.glob("*.zip"):
|
123 |
+
model_name = filename.stem
|
124 |
+
model_choices.append(model_name)
|
125 |
+
|
126 |
+
# ui
|
127 |
+
brief_description = """
|
128 |
+
国际语音智能外呼系统, 电话声音分类.
|
129 |
+
"""
|
130 |
+
|
131 |
+
# ui
|
132 |
+
with gr.Blocks() as blocks:
|
133 |
+
gr.Markdown(value=brief_description)
|
134 |
+
|
135 |
+
with gr.Row():
|
136 |
+
with gr.Column(scale=3):
|
137 |
+
c_audio = gr.Audio(label="audio")
|
138 |
+
with gr.Row():
|
139 |
+
with gr.Column(scale=3):
|
140 |
+
c_model_name = gr.Dropdown(choices=model_choices, value=model_choices[0], label="language")
|
141 |
+
with gr.Column(scale=3):
|
142 |
+
c_ground_true = gr.Textbox(label="ground_true")
|
143 |
+
|
144 |
+
c_button = gr.Button("run", variant="primary")
|
145 |
+
with gr.Column(scale=3):
|
146 |
+
c_label = gr.Textbox(label="label")
|
147 |
+
c_probability = gr.Number(label="probability")
|
148 |
+
|
149 |
+
gr.Examples(
|
150 |
+
examples,
|
151 |
+
inputs=[c_audio, c_model_name, c_ground_true],
|
152 |
+
outputs=[c_label, c_probability],
|
153 |
+
fn=click_button,
|
154 |
+
examples_per_page=5,
|
155 |
+
)
|
156 |
+
|
157 |
+
c_button.click(
|
158 |
+
click_button,
|
159 |
+
inputs=[c_audio, c_model_name, c_ground_true],
|
160 |
+
outputs=[c_label, c_probability],
|
161 |
+
)
|
162 |
+
|
163 |
+
blocks.queue().launch(
|
164 |
+
share=False if platform.system() == "Windows" else False,
|
165 |
+
server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
|
166 |
+
server_port=args.server_port
|
167 |
+
)
|
168 |
+
return
|
169 |
+
|
170 |
+
|
171 |
+
if __name__ == "__main__":
|
172 |
+
main()
|
project_settings.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import os
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
from toolbox.os.environment import EnvironmentManager
|
7 |
+
|
8 |
+
|
9 |
+
project_path = os.path.abspath(os.path.dirname(__file__))
|
10 |
+
project_path = Path(project_path)
|
11 |
+
|
12 |
+
environment = EnvironmentManager(
|
13 |
+
path=os.path.join(project_path, "dotenv"),
|
14 |
+
env=os.environ.get("environment", "dev"),
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
if __name__ == '__main__':
|
19 |
+
pass
|
requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.3.0
|
2 |
+
torchaudio==2.3.0
|
3 |
+
fsspec==2024.5.0
|
4 |
+
librosa==0.10.2
|
5 |
+
pandas==2.0.3
|
6 |
+
openpyxl==3.0.9
|
7 |
+
xlrd==1.2.0
|
8 |
+
tqdm==4.66.4
|
9 |
+
overrides==1.9.0
|
10 |
+
pyyaml==6.0.1
|
11 |
+
evaluate==0.4.2
|
12 |
+
gradio
|
script/install_nvidia_driver.sh
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
#GPU驱动安装需要先将原有的显示关闭, 重启机器, 再进行安装.
|
3 |
+
#参考链接:
|
4 |
+
#https://blog.csdn.net/kingschan/article/details/19033595
|
5 |
+
#https://blog.csdn.net/HaixWang/article/details/90408538
|
6 |
+
#
|
7 |
+
#>>> yum install -y pciutils
|
8 |
+
#查看 linux 机器上是否有 GPU
|
9 |
+
#lspci |grep -i nvidia
|
10 |
+
#
|
11 |
+
#>>> lspci |grep -i nvidia
|
12 |
+
#00:08.0 3D controller: NVIDIA Corporation TU104GL [Tesla T4] (rev a1)
|
13 |
+
#
|
14 |
+
#
|
15 |
+
#NVIDIA 驱动程序下载
|
16 |
+
#先在 pytorch 上查看应该用什么 cuda 版本, 再安装对应的 cuda-toolkit cuda.
|
17 |
+
#再根据 gpu 版本下载安装对应的 nvidia 驱动
|
18 |
+
#
|
19 |
+
## pytorch 版本
|
20 |
+
#https://pytorch.org/get-started/locally/
|
21 |
+
#
|
22 |
+
## CUDA 下载 (好像不需要这个)
|
23 |
+
#https://developer.nvidia.com/cuda-toolkit-archive
|
24 |
+
#
|
25 |
+
## nvidia 驱动
|
26 |
+
#https://www.nvidia.cn/Download/index.aspx?lang=cn
|
27 |
+
#http://www.nvidia.com/Download/index.aspx
|
28 |
+
#
|
29 |
+
#在下方的下拉列表中进行选择,针对您的 NVIDIA 产品确定合适的驱动。
|
30 |
+
#产品类型:
|
31 |
+
#Data Center / Tesla
|
32 |
+
#产品系列:
|
33 |
+
#T-Series
|
34 |
+
#产品家族:
|
35 |
+
#Tesla T4
|
36 |
+
#操作系统:
|
37 |
+
#Linux 64-bit
|
38 |
+
#CUDA Toolkit:
|
39 |
+
#10.2
|
40 |
+
#语言:
|
41 |
+
#Chinese (Simpleified)
|
42 |
+
#
|
43 |
+
#
|
44 |
+
#>>> mkdir -p /data/tianxing
|
45 |
+
#>>> cd /data/tianxing
|
46 |
+
#>>> wget https://cn.download.nvidia.com/tesla/440.118.02/NVIDIA-Linux-x86_64-440.118.02.run
|
47 |
+
#>>> sh NVIDIA-Linux-x86_64-440.118.02.run
|
48 |
+
#
|
49 |
+
## 异常:
|
50 |
+
#ERROR: The Nouveau kernel driver is currently in use by your system. This driver is incompatible with the NVIDIA driver, and must be disabled before proceeding. Please consult the NVIDIA driver README and your
|
51 |
+
#Linux distribution's documentation for details on how to correctly disable the Nouveau kernel driver.
|
52 |
+
#[OK]
|
53 |
+
#
|
54 |
+
#For some distributions, Nouveau can be disabled by adding a file in the modprobe configuration directory. Would you like nvidia-installer to attempt to create this modprobe file for you?
|
55 |
+
#[NO]
|
56 |
+
#
|
57 |
+
#ERROR: Installation has failed. Please see the file '/var/log/nvidia-installer.log' for details. You may find suggestions on fixing installation problems in the README available on the Linux driver download
|
58 |
+
#page at www.nvidia.com.
|
59 |
+
#[OK]
|
60 |
+
#
|
61 |
+
## 参考链接:
|
62 |
+
#https://blog.csdn.net/kingschan/article/details/19033595
|
63 |
+
#
|
64 |
+
## 禁用原有的显卡驱动 nouveau
|
65 |
+
#>>> echo -e "blacklist nouveau\noptions nouveau modeset=0\n" > /etc/modprobe.d/blacklist-nouveau.conf
|
66 |
+
#>>> sudo dracut --force
|
67 |
+
## 重启
|
68 |
+
#>>> reboot
|
69 |
+
#
|
70 |
+
#>>> init 3
|
71 |
+
#>>> sh NVIDIA-Linux-x86_64-440.118.02.run
|
72 |
+
#
|
73 |
+
## 异常
|
74 |
+
#ERROR: Unable to find the kernel source tree for the currently running kernel. Please make sure you have installed the kernel source files for your kernel and that they are properly configured; on Red Hat Linux systems, for example, be sure you have the 'kernel-source' or 'kernel-devel' RPM installed. If you know the correct kernel source files are installed, you may specify the kernel source path with the '--kernel-source-path' command line option.
|
75 |
+
#[OK]
|
76 |
+
#ERROR: Installation has failed. Please see the file '/var/log/nvidia-installer.log' for details. You may find suggestions on fixing installation problems in the README available on the Linux driver download
|
77 |
+
#page at www.nvidia.com.
|
78 |
+
#[OK]
|
79 |
+
#
|
80 |
+
## 参考链接
|
81 |
+
## https://blog.csdn.net/HaixWang/article/details/90408538
|
82 |
+
#
|
83 |
+
#>>> uname -r
|
84 |
+
#3.10.0-1160.49.1.el7.x86_64
|
85 |
+
#>>> yum install kernel-devel kernel-headers -y
|
86 |
+
#>>> yum info kernel-devel kernel-headers
|
87 |
+
#>>> yum install -y "kernel-devel-uname-r == $(uname -r)"
|
88 |
+
#>>> yum -y distro-sync
|
89 |
+
#
|
90 |
+
#>>> sh NVIDIA-Linux-x86_64-440.118.02.run
|
91 |
+
#
|
92 |
+
## 安装成功
|
93 |
+
#WARNING: nvidia-installer was forced to guess the X library path '/usr/lib64' and X module path '/usr/lib64/xorg/modules'; these paths were not queryable from the system. If X fails to find the NVIDIA X driver
|
94 |
+
#module, please install the `pkg-config` utility and the X.Org SDK/development package for your distribution and reinstall the driver.
|
95 |
+
#[OK]
|
96 |
+
#Install NVIDIA's 32-bit compatibility libraries?
|
97 |
+
#[YES]
|
98 |
+
#Installation of the kernel module for the NVIDIA Accelerated Graphics Driver for Linux-x86_64 (version 440.118.02) is now complete.
|
99 |
+
#[OK]
|
100 |
+
#
|
101 |
+
#
|
102 |
+
## 查看 GPU 使用情况; watch -n 1 -d nvidia-smi 每1秒刷新一次.
|
103 |
+
#>>> nvidia-smi
|
104 |
+
#Thu Mar 9 12:00:37 2023
|
105 |
+
#+-----------------------------------------------------------------------------+
|
106 |
+
#| NVIDIA-SMI 440.118.02 Driver Version: 440.118.02 CUDA Version: 10.2 |
|
107 |
+
#|-------------------------------+----------------------+----------------------+
|
108 |
+
#| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
|
109 |
+
#| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
|
110 |
+
#|===============================+======================+======================|
|
111 |
+
#| 0 Tesla T4 Off | 00000000:00:08.0 Off | Off |
|
112 |
+
#| N/A 54C P0 22W / 70W | 0MiB / 16127MiB | 0% Default |
|
113 |
+
#+-------------------------------+----------------------+----------------------+
|
114 |
+
#
|
115 |
+
#+-----------------------------------------------------------------------------+
|
116 |
+
#| Processes: GPU Memory |
|
117 |
+
#| GPU PID Type Process name Usage |
|
118 |
+
#|=============================================================================|
|
119 |
+
#| No running processes found |
|
120 |
+
#+-----------------------------------------------------------------------------+
|
121 |
+
#
|
122 |
+
#
|
123 |
+
|
124 |
+
# params
|
125 |
+
stage=1
|
126 |
+
nvidia_driver_filename=https://cn.download.nvidia.com/tesla/440.118.02/NVIDIA-Linux-x86_64-440.118.02.run
|
127 |
+
|
128 |
+
# parse options
|
129 |
+
while true; do
|
130 |
+
[ -z "${1:-}" ] && break; # break if there are no arguments
|
131 |
+
case "$1" in
|
132 |
+
--*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
|
133 |
+
eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
|
134 |
+
old_value="(eval echo \\$$name)";
|
135 |
+
if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
|
136 |
+
was_bool=true;
|
137 |
+
else
|
138 |
+
was_bool=false;
|
139 |
+
fi
|
140 |
+
|
141 |
+
# Set the variable to the right value-- the escaped quotes make it work if
|
142 |
+
# the option had spaces, like --cmd "queue.pl -sync y"
|
143 |
+
eval "${name}=\"$2\"";
|
144 |
+
|
145 |
+
# Check that Boolean-valued arguments are really Boolean.
|
146 |
+
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
|
147 |
+
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
|
148 |
+
exit 1;
|
149 |
+
fi
|
150 |
+
shift 2;
|
151 |
+
;;
|
152 |
+
|
153 |
+
*) break;
|
154 |
+
esac
|
155 |
+
done
|
156 |
+
|
157 |
+
echo "stage: ${stage}";
|
158 |
+
|
159 |
+
yum -y install wget
|
160 |
+
yum -y install sudo
|
161 |
+
|
162 |
+
if [ ${stage} -eq 0 ]; then
|
163 |
+
mkdir -p /data/dep
|
164 |
+
cd /data/dep || echo 1;
|
165 |
+
wget -P /data/dep ${nvidia_driver_filename}
|
166 |
+
|
167 |
+
echo -e "blacklist nouveau\noptions nouveau modeset=0\n" > /etc/modprobe.d/blacklist-nouveau.conf
|
168 |
+
sudo dracut --force
|
169 |
+
# 重启
|
170 |
+
reboot
|
171 |
+
elif [ ${stage} -eq 1 ]; then
|
172 |
+
init 3
|
173 |
+
|
174 |
+
yum install -y kernel-devel kernel-headers
|
175 |
+
yum info kernel-devel kernel-headers
|
176 |
+
yum install -y "kernel-devel-uname-r == $(uname -r)"
|
177 |
+
yum -y distro-sync
|
178 |
+
|
179 |
+
cd /data/dep || echo 1;
|
180 |
+
|
181 |
+
# 安装时, 需要回车三下.
|
182 |
+
sh NVIDIA-Linux-x86_64-440.118.02.run
|
183 |
+
nvidia-smi
|
184 |
+
fi
|
script/install_python.sh
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# 参数:
|
4 |
+
python_version="3.6.5";
|
5 |
+
system_version="centos";
|
6 |
+
|
7 |
+
|
8 |
+
# parse options
|
9 |
+
while true; do
|
10 |
+
[ -z "${1:-}" ] && break; # break if there are no arguments
|
11 |
+
case "$1" in
|
12 |
+
--*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
|
13 |
+
eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
|
14 |
+
old_value="(eval echo \\$$name)";
|
15 |
+
if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
|
16 |
+
was_bool=true;
|
17 |
+
else
|
18 |
+
was_bool=false;
|
19 |
+
fi
|
20 |
+
|
21 |
+
# Set the variable to the right value-- the escaped quotes make it work if
|
22 |
+
# the option had spaces, like --cmd "queue.pl -sync y"
|
23 |
+
eval "${name}=\"$2\"";
|
24 |
+
|
25 |
+
# Check that Boolean-valued arguments are really Boolean.
|
26 |
+
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
|
27 |
+
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
|
28 |
+
exit 1;
|
29 |
+
fi
|
30 |
+
shift 2;
|
31 |
+
;;
|
32 |
+
|
33 |
+
*) break;
|
34 |
+
esac
|
35 |
+
done
|
36 |
+
|
37 |
+
echo "python_version: ${python_version}";
|
38 |
+
echo "system_version: ${system_version}";
|
39 |
+
|
40 |
+
|
41 |
+
if [ ${system_version} = "centos" ]; then
|
42 |
+
# 安装 python 开发编译环境
|
43 |
+
yum -y groupinstall "Development tools"
|
44 |
+
yum -y install zlib-devel bzip2-devel openssl-devel ncurses-devel sqlite-devel readline-devel tk-devel gdbm-devel db4-devel libpcap-devel xz-devel
|
45 |
+
yum install libffi-devel -y
|
46 |
+
yum install -y wget
|
47 |
+
yum install -y make
|
48 |
+
|
49 |
+
mkdir -p /data/dep
|
50 |
+
cd /data/dep || exit 1;
|
51 |
+
if [ ! -e Python-${python_version}.tgz ]; then
|
52 |
+
wget -P /data/dep https://www.python.org/ftp/python/${python_version}/Python-${python_version}.tgz
|
53 |
+
fi
|
54 |
+
|
55 |
+
cd /data/dep || exit 1;
|
56 |
+
if [ ! -d Python-${python_version} ]; then
|
57 |
+
tar -zxvf Python-${python_version}.tgz
|
58 |
+
cd /data/dep/Python-${python_version} || exit 1;
|
59 |
+
fi
|
60 |
+
|
61 |
+
mkdir /usr/local/python-${python_version}
|
62 |
+
./configure --prefix=/usr/local/python-${python_version}
|
63 |
+
make && make install
|
64 |
+
|
65 |
+
/usr/local/python-${python_version}/bin/python3 -V
|
66 |
+
/usr/local/python-${python_version}/bin/pip3 -V
|
67 |
+
|
68 |
+
rm -rf /usr/local/bin/python3
|
69 |
+
rm -rf /usr/local/bin/pip3
|
70 |
+
ln -s /usr/local/python-${python_version}/bin/python3 /usr/local/bin/python3
|
71 |
+
ln -s /usr/local/python-${python_version}/bin/pip3 /usr/local/bin/pip3
|
72 |
+
|
73 |
+
python3 -V
|
74 |
+
pip3 -V
|
75 |
+
|
76 |
+
elif [ ${system_version} = "ubuntu" ]; then
|
77 |
+
# 安装 python 开发编译环境
|
78 |
+
# https://zhuanlan.zhihu.com/p/506491209
|
79 |
+
|
80 |
+
# 刷新软件包目录
|
81 |
+
sudo apt update
|
82 |
+
# 列出当前可用的更新
|
83 |
+
sudo apt list --upgradable
|
84 |
+
# 如上一步提示有可以更新的项目,则执行更新
|
85 |
+
sudo apt -y upgrade
|
86 |
+
# 安装 GCC 编译器
|
87 |
+
sudo apt install gcc
|
88 |
+
# 检查安装是否成功
|
89 |
+
gcc -v
|
90 |
+
|
91 |
+
# 安装依赖
|
92 |
+
sudo apt install -y build-essential zlib1g-dev libncurses5-dev libgdbm-dev libnss3-dev libssl-dev libreadline-dev libffi-dev libbz2-dev liblzma-dev sqlite3 libsqlite3-dev tk-dev uuid-dev libgdbm-compat-dev
|
93 |
+
|
94 |
+
mkdir -p /data/dep
|
95 |
+
cd /data/dep || exit 1;
|
96 |
+
if [ ! -e Python-${python_version}.tgz ]; then
|
97 |
+
# sudo wget -P /data/dep https://www.python.org/ftp/python/3.6.5/Python-3.6.5.tgz
|
98 |
+
sudo wget -P /data/dep https://www.python.org/ftp/python/${python_version}/Python-${python_version}.tgz
|
99 |
+
fi
|
100 |
+
|
101 |
+
cd /data/dep || exit 1;
|
102 |
+
if [ ! -d Python-${python_version} ]; then
|
103 |
+
# tar -zxvf Python-3.6.5.tgz
|
104 |
+
tar -zxvf Python-${python_version}.tgz
|
105 |
+
# cd /data/dep/Python-3.6.5
|
106 |
+
cd /data/dep/Python-${python_version} || exit 1;
|
107 |
+
fi
|
108 |
+
|
109 |
+
# mkdir /usr/local/python-3.6.5
|
110 |
+
mkdir /usr/local/python-${python_version}
|
111 |
+
|
112 |
+
# 检查依赖与配置编译
|
113 |
+
# sudo ./configure --prefix=/usr/local/python-3.6.5 --enable-optimizations --with-lto --enable-shared
|
114 |
+
sudo ./configure --prefix=/usr/local/python-${python_version} --enable-optimizations --with-lto --enable-shared
|
115 |
+
cpu_count=$(cat /proc/cpuinfo | grep processor | wc -l)
|
116 |
+
# sudo make -j 4
|
117 |
+
sudo make -j "${cpu_count}"
|
118 |
+
|
119 |
+
/usr/local/python-${python_version}/bin/python3 -V
|
120 |
+
/usr/local/python-${python_version}/bin/pip3 -V
|
121 |
+
|
122 |
+
rm -rf /usr/local/bin/python3
|
123 |
+
rm -rf /usr/local/bin/pip3
|
124 |
+
ln -s /usr/local/python-${python_version}/bin/python3 /usr/local/bin/python3
|
125 |
+
ln -s /usr/local/python-${python_version}/bin/pip3 /usr/local/bin/pip3
|
126 |
+
|
127 |
+
python3 -V
|
128 |
+
pip3 -V
|
129 |
+
fi
|
toolbox/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
if __name__ == '__main__':
|
5 |
+
pass
|
toolbox/json/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == '__main__':
|
6 |
+
pass
|
toolbox/json/misc.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
from typing import Callable
|
4 |
+
|
5 |
+
|
6 |
+
def traverse(js, callback: Callable, *args, **kwargs):
|
7 |
+
if isinstance(js, list):
|
8 |
+
result = list()
|
9 |
+
for l in js:
|
10 |
+
l = traverse(l, callback, *args, **kwargs)
|
11 |
+
result.append(l)
|
12 |
+
return result
|
13 |
+
elif isinstance(js, tuple):
|
14 |
+
result = list()
|
15 |
+
for l in js:
|
16 |
+
l = traverse(l, callback, *args, **kwargs)
|
17 |
+
result.append(l)
|
18 |
+
return tuple(result)
|
19 |
+
elif isinstance(js, dict):
|
20 |
+
result = dict()
|
21 |
+
for k, v in js.items():
|
22 |
+
k = traverse(k, callback, *args, **kwargs)
|
23 |
+
v = traverse(v, callback, *args, **kwargs)
|
24 |
+
result[k] = v
|
25 |
+
return result
|
26 |
+
elif isinstance(js, int):
|
27 |
+
return callback(js, *args, **kwargs)
|
28 |
+
elif isinstance(js, str):
|
29 |
+
return callback(js, *args, **kwargs)
|
30 |
+
else:
|
31 |
+
return js
|
32 |
+
|
33 |
+
|
34 |
+
def demo1():
|
35 |
+
d = {
|
36 |
+
"env": "ppe",
|
37 |
+
"mysql_connect": {
|
38 |
+
"host": "$mysql_connect_host",
|
39 |
+
"port": 3306,
|
40 |
+
"user": "callbot",
|
41 |
+
"password": "NxcloudAI2021!",
|
42 |
+
"database": "callbot_ppe",
|
43 |
+
"charset": "utf8"
|
44 |
+
},
|
45 |
+
"es_connect": {
|
46 |
+
"hosts": ["10.20.251.8"],
|
47 |
+
"http_auth": ["elastic", "ElasticAI2021!"],
|
48 |
+
"port": 9200
|
49 |
+
}
|
50 |
+
}
|
51 |
+
|
52 |
+
def callback(s):
|
53 |
+
if isinstance(s, str) and s.startswith('$'):
|
54 |
+
return s[1:]
|
55 |
+
return s
|
56 |
+
|
57 |
+
result = traverse(d, callback=callback)
|
58 |
+
print(result)
|
59 |
+
return
|
60 |
+
|
61 |
+
|
62 |
+
if __name__ == '__main__':
|
63 |
+
demo1()
|
toolbox/os/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == '__main__':
|
6 |
+
pass
|
toolbox/os/command.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
class Command(object):
|
7 |
+
custom_command = [
|
8 |
+
"cd"
|
9 |
+
]
|
10 |
+
|
11 |
+
@staticmethod
|
12 |
+
def _get_cmd(command):
|
13 |
+
command = str(command).strip()
|
14 |
+
if command == "":
|
15 |
+
return None
|
16 |
+
cmd_and_args = command.split(sep=" ")
|
17 |
+
cmd = cmd_and_args[0]
|
18 |
+
args = " ".join(cmd_and_args[1:])
|
19 |
+
return cmd, args
|
20 |
+
|
21 |
+
@classmethod
|
22 |
+
def popen(cls, command):
|
23 |
+
cmd, args = cls._get_cmd(command)
|
24 |
+
if cmd in cls.custom_command:
|
25 |
+
method = getattr(cls, cmd)
|
26 |
+
return method(args)
|
27 |
+
else:
|
28 |
+
resp = os.popen(command)
|
29 |
+
result = resp.read()
|
30 |
+
resp.close()
|
31 |
+
return result
|
32 |
+
|
33 |
+
@classmethod
|
34 |
+
def cd(cls, args):
|
35 |
+
if args.startswith("/"):
|
36 |
+
os.chdir(args)
|
37 |
+
else:
|
38 |
+
pwd = os.getcwd()
|
39 |
+
path = os.path.join(pwd, args)
|
40 |
+
os.chdir(path)
|
41 |
+
|
42 |
+
@classmethod
|
43 |
+
def system(cls, command):
|
44 |
+
return os.system(command)
|
45 |
+
|
46 |
+
def __init__(self):
|
47 |
+
pass
|
48 |
+
|
49 |
+
|
50 |
+
def ps_ef_grep(keyword: str):
|
51 |
+
cmd = "ps -ef | grep {}".format(keyword)
|
52 |
+
rows = Command.popen(cmd)
|
53 |
+
rows = str(rows).split("\n")
|
54 |
+
rows = [row for row in rows if row.__contains__(keyword) and not row.__contains__("grep")]
|
55 |
+
return rows
|
56 |
+
|
57 |
+
|
58 |
+
if __name__ == "__main__":
|
59 |
+
pass
|
toolbox/os/environment.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
|
6 |
+
from dotenv import load_dotenv
|
7 |
+
from dotenv.main import DotEnv
|
8 |
+
|
9 |
+
from toolbox.json.misc import traverse
|
10 |
+
|
11 |
+
|
12 |
+
class EnvironmentManager(object):
|
13 |
+
def __init__(self, path, env, override=False):
|
14 |
+
filename = os.path.join(path, '{}.env'.format(env))
|
15 |
+
self.filename = filename
|
16 |
+
|
17 |
+
load_dotenv(
|
18 |
+
dotenv_path=filename,
|
19 |
+
override=override
|
20 |
+
)
|
21 |
+
|
22 |
+
self._environ = dict()
|
23 |
+
|
24 |
+
def open_dotenv(self, filename: str = None):
|
25 |
+
filename = filename or self.filename
|
26 |
+
dotenv = DotEnv(
|
27 |
+
dotenv_path=filename,
|
28 |
+
stream=None,
|
29 |
+
verbose=False,
|
30 |
+
interpolate=False,
|
31 |
+
override=False,
|
32 |
+
encoding="utf-8",
|
33 |
+
)
|
34 |
+
result = dotenv.dict()
|
35 |
+
return result
|
36 |
+
|
37 |
+
def get(self, key, default=None, dtype=str):
|
38 |
+
result = os.environ.get(key)
|
39 |
+
if result is None:
|
40 |
+
if default is None:
|
41 |
+
result = None
|
42 |
+
else:
|
43 |
+
result = default
|
44 |
+
else:
|
45 |
+
result = dtype(result)
|
46 |
+
self._environ[key] = result
|
47 |
+
return result
|
48 |
+
|
49 |
+
|
50 |
+
_DEFAULT_DTYPE_MAP = {
|
51 |
+
'int': int,
|
52 |
+
'float': float,
|
53 |
+
'str': str,
|
54 |
+
'json.loads': json.loads
|
55 |
+
}
|
56 |
+
|
57 |
+
|
58 |
+
class JsonConfig(object):
|
59 |
+
"""
|
60 |
+
将 json 中, 形如 `$float:threshold` 的值, 处理为:
|
61 |
+
从环境变量中查到 threshold, 再将其转换为 float 类型.
|
62 |
+
"""
|
63 |
+
def __init__(self, dtype_map: dict = None, environment: EnvironmentManager = None):
|
64 |
+
self.dtype_map = dtype_map or _DEFAULT_DTYPE_MAP
|
65 |
+
self.environment = environment or os.environ
|
66 |
+
|
67 |
+
def sanitize_by_filename(self, filename: str):
|
68 |
+
with open(filename, 'r', encoding='utf-8') as f:
|
69 |
+
js = json.load(f)
|
70 |
+
|
71 |
+
return self.sanitize_by_json(js)
|
72 |
+
|
73 |
+
def sanitize_by_json(self, js):
|
74 |
+
js = traverse(
|
75 |
+
js,
|
76 |
+
callback=self.sanitize,
|
77 |
+
environment=self.environment
|
78 |
+
)
|
79 |
+
return js
|
80 |
+
|
81 |
+
def sanitize(self, string, environment):
|
82 |
+
"""支持 $ 符开始的, 环境变量配置"""
|
83 |
+
if isinstance(string, str) and string.startswith('$'):
|
84 |
+
dtype, key = string[1:].split(':')
|
85 |
+
dtype = self.dtype_map[dtype]
|
86 |
+
|
87 |
+
value = environment.get(key)
|
88 |
+
if value is None:
|
89 |
+
raise AssertionError('environment not exist. key: {}'.format(key))
|
90 |
+
|
91 |
+
value = dtype(value)
|
92 |
+
result = value
|
93 |
+
else:
|
94 |
+
result = string
|
95 |
+
return result
|
96 |
+
|
97 |
+
|
98 |
+
def demo1():
|
99 |
+
import json
|
100 |
+
|
101 |
+
from project_settings import project_path
|
102 |
+
|
103 |
+
environment = EnvironmentManager(
|
104 |
+
path=os.path.join(project_path, 'server/callbot_server/dotenv'),
|
105 |
+
env='dev',
|
106 |
+
)
|
107 |
+
init_scenes = environment.get(key='init_scenes', dtype=json.loads)
|
108 |
+
print(init_scenes)
|
109 |
+
print(environment._environ)
|
110 |
+
return
|
111 |
+
|
112 |
+
|
113 |
+
if __name__ == '__main__':
|
114 |
+
demo1()
|
toolbox/os/other.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import inspect
|
3 |
+
|
4 |
+
|
5 |
+
def pwd():
|
6 |
+
"""你在哪个文件调用此函数, 它就会返回那个文件所在的 dir 目标"""
|
7 |
+
frame = inspect.stack()[1]
|
8 |
+
module = inspect.getmodule(frame[0])
|
9 |
+
return os.path.dirname(os.path.abspath(module.__file__))
|
toolbox/torch/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
if __name__ == '__main__':
|
5 |
+
pass
|
toolbox/torch/modules/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == '__main__':
|
6 |
+
pass
|
toolbox/torch/modules/gaussian_mixture.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
https://github.com/georgepar/gmmhmm-pytorch/blob/master/gmm.py
|
5 |
+
https://github.com/ldeecke/gmm-torch
|
6 |
+
"""
|
7 |
+
import math
|
8 |
+
|
9 |
+
from sklearn import cluster
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
|
13 |
+
|
14 |
+
class GaussianMixtureModel(nn.Module):
|
15 |
+
def __init__(self,
|
16 |
+
n_mixtures: int,
|
17 |
+
n_features: int,
|
18 |
+
init: str = "random",
|
19 |
+
device: str = 'cpu',
|
20 |
+
n_iter: int = 1000,
|
21 |
+
delta: float = 1e-3,
|
22 |
+
warm_start: bool = False,
|
23 |
+
):
|
24 |
+
super(GaussianMixtureModel, self).__init__()
|
25 |
+
self.n_mixtures = n_mixtures
|
26 |
+
self.n_features = n_features
|
27 |
+
self.init = init
|
28 |
+
self.device = device
|
29 |
+
self.n_iter = n_iter
|
30 |
+
self.delta = delta
|
31 |
+
self.warm_start = warm_start
|
32 |
+
|
33 |
+
if init not in ('kmeans', 'random'):
|
34 |
+
raise AssertionError
|
35 |
+
|
36 |
+
self.mu = nn.Parameter(
|
37 |
+
torch.Tensor(n_mixtures, n_features),
|
38 |
+
requires_grad=False,
|
39 |
+
)
|
40 |
+
|
41 |
+
self.sigma = None
|
42 |
+
|
43 |
+
# the weight of each gaussian
|
44 |
+
self.pi = nn.Parameter(
|
45 |
+
torch.Tensor(n_mixtures),
|
46 |
+
requires_grad=False
|
47 |
+
)
|
48 |
+
|
49 |
+
self.converged_ = False
|
50 |
+
self.eps = 1e-6
|
51 |
+
self.delta = delta
|
52 |
+
self.warm_start = warm_start
|
53 |
+
self.n_iter = n_iter
|
54 |
+
|
55 |
+
def reset_sigma(self):
|
56 |
+
raise NotImplementedError
|
57 |
+
|
58 |
+
def estimate_precisions(self):
|
59 |
+
raise NotImplementedError
|
60 |
+
|
61 |
+
def log_prob(self, x):
|
62 |
+
raise NotImplementedError
|
63 |
+
|
64 |
+
def weighted_log_prob(self, x):
|
65 |
+
log_prob = self.log_prob(x)
|
66 |
+
weighted_log_prob = log_prob + torch.log(self.pi)
|
67 |
+
return weighted_log_prob
|
68 |
+
|
69 |
+
def log_likelihood(self, x):
|
70 |
+
weighted_log_prob = self.weighted_log_prob(x)
|
71 |
+
per_sample_log_likelihood = torch.logsumexp(weighted_log_prob, dim=1)
|
72 |
+
log_likelihood = torch.sum(per_sample_log_likelihood)
|
73 |
+
return log_likelihood
|
74 |
+
|
75 |
+
def e_step(self, x):
|
76 |
+
weighted_log_prob = self.weighted_log_prob(x)
|
77 |
+
weighted_log_prob = weighted_log_prob.unsqueeze(dim=-1)
|
78 |
+
log_likelihood = torch.logsumexp(weighted_log_prob, dim=1, keepdim=True)
|
79 |
+
q = weighted_log_prob - log_likelihood
|
80 |
+
return q.squeeze()
|
81 |
+
|
82 |
+
def m_step(self, x, q):
|
83 |
+
x = x.unsqueeze(dim=1)
|
84 |
+
|
85 |
+
return
|
86 |
+
|
87 |
+
def estimate_mu(self, x, pi, responsibilities):
|
88 |
+
nk = pi * x.size(0)
|
89 |
+
mu = torch.sum(responsibilities * x, dim=0, keepdim=True) / nk
|
90 |
+
return mu
|
91 |
+
|
92 |
+
def estimate_pi(self, x, responsibilities):
|
93 |
+
pi = torch.sum(responsibilities, dim=0, keepdim=True) + self.eps
|
94 |
+
pi = pi / x.size(0)
|
95 |
+
return pi
|
96 |
+
|
97 |
+
def reset_parameters(self, x=None):
|
98 |
+
if self.init == 'random' or x is None:
|
99 |
+
self.mu.normal_()
|
100 |
+
self.reset_sigma()
|
101 |
+
self.pi.fill_(1.0 / self.n_mixtures)
|
102 |
+
elif self.init == 'kmeans':
|
103 |
+
centroids = cluster.KMeans(n_clusters=self.n_mixtures, n_init=1).fit(x).cluster_centers_
|
104 |
+
centroids = torch.tensor(centroids).to(self.device)
|
105 |
+
self.update_(mu=centroids)
|
106 |
+
else:
|
107 |
+
raise NotImplementedError
|
108 |
+
|
109 |
+
|
110 |
+
class DiagonalCovarianceGMM(GaussianMixtureModel):
|
111 |
+
def __init__(self,
|
112 |
+
n_mixtures: int,
|
113 |
+
n_features: int,
|
114 |
+
init: str = "random",
|
115 |
+
device: str = 'cpu',
|
116 |
+
n_iter: int = 1000,
|
117 |
+
delta: float = 1e-3,
|
118 |
+
warm_start: bool = False,
|
119 |
+
):
|
120 |
+
super(DiagonalCovarianceGMM, self).__init__(
|
121 |
+
n_mixtures=n_mixtures,
|
122 |
+
n_features=n_features,
|
123 |
+
init=init,
|
124 |
+
device=device,
|
125 |
+
n_iter=n_iter,
|
126 |
+
delta=delta,
|
127 |
+
warm_start=warm_start,
|
128 |
+
)
|
129 |
+
self.sigma = nn.Parameter(
|
130 |
+
torch.Tensor(n_mixtures, n_features), requires_grad=False
|
131 |
+
)
|
132 |
+
self.reset_parameters()
|
133 |
+
self.to(self.device)
|
134 |
+
|
135 |
+
def reset_sigma(self):
|
136 |
+
self.sigma.fill_(1)
|
137 |
+
|
138 |
+
def estimate_precisions(self):
|
139 |
+
return torch.rsqrt(self.sigma)
|
140 |
+
|
141 |
+
def log_prob(self, x):
|
142 |
+
precisions = self.estimate_precisions()
|
143 |
+
|
144 |
+
x = x.unsqueeze(1)
|
145 |
+
mu = self.mu.unsqueeze(0)
|
146 |
+
precisions = precisions.unsqueeze(0)
|
147 |
+
|
148 |
+
# This is outer product
|
149 |
+
exp_term = torch.sum(
|
150 |
+
(mu * mu + x * x - 2 * x * mu) * (precisions ** 2), dim=2, keepdim=True
|
151 |
+
)
|
152 |
+
log_det = torch.sum(torch.log(precisions), dim=2, keepdim=True)
|
153 |
+
|
154 |
+
logp = -0.5 * (self.n_features * torch.log(2 * math.pi) + exp_term) + log_det
|
155 |
+
|
156 |
+
return logp.squeeze()
|
157 |
+
|
158 |
+
def estimate_sigma(self, x, mu, pi, responsibilities):
|
159 |
+
nk = pi * x.size(0)
|
160 |
+
x2 = (responsibilities * x * x).sum(0, keepdim=True) / nk
|
161 |
+
mu2 = mu * mu
|
162 |
+
xmu = (responsibilities * mu * x).sum(0, keepdim=True) / nk
|
163 |
+
sigma = x2 - 2 * xmu + mu2 + self.eps
|
164 |
+
|
165 |
+
return sigma
|
166 |
+
|
167 |
+
|
168 |
+
def demo1():
|
169 |
+
return
|
170 |
+
|
171 |
+
|
172 |
+
if __name__ == '__main__':
|
173 |
+
demo1()
|
toolbox/torch/modules/highway.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
|
6 |
+
class Highway(nn.Module):
|
7 |
+
"""
|
8 |
+
https://arxiv.org/abs/1505.00387
|
9 |
+
[Submitted on 3 May 2015 (v1), last revised 3 Nov 2015 (this version, v2)]
|
10 |
+
|
11 |
+
discuss of Highway and ResNet
|
12 |
+
https://www.zhihu.com/question/279426970
|
13 |
+
"""
|
14 |
+
def __init__(self, in_size, out_size):
|
15 |
+
super(Highway, self).__init__()
|
16 |
+
self.H = nn.Linear(in_size, out_size)
|
17 |
+
self.H.bias.data.zero_()
|
18 |
+
self.T = nn.Linear(in_size, out_size)
|
19 |
+
self.T.bias.data.fill_(-1)
|
20 |
+
self.relu = nn.ReLU()
|
21 |
+
self.sigmoid = nn.Sigmoid()
|
22 |
+
|
23 |
+
def forward(self, inputs):
|
24 |
+
H = self.relu(self.H(inputs))
|
25 |
+
T = self.sigmoid(self.T(inputs))
|
26 |
+
return H * T + inputs * (1.0 - T)
|
27 |
+
|
28 |
+
|
29 |
+
if __name__ == '__main__':
|
30 |
+
pass
|
toolbox/torch/modules/loss.py
ADDED
@@ -0,0 +1,738 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import math
|
4 |
+
from typing import List, Optional
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch.nn.modules.loss import _Loss
|
11 |
+
from torch.autograd import Variable
|
12 |
+
|
13 |
+
|
14 |
+
class ClassBalancedLoss(_Loss):
|
15 |
+
"""
|
16 |
+
https://arxiv.org/abs/1901.05555
|
17 |
+
"""
|
18 |
+
@staticmethod
|
19 |
+
def demo1():
|
20 |
+
batch_loss: torch.FloatTensor = torch.randn(size=(2, 1), dtype=torch.float32)
|
21 |
+
targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)
|
22 |
+
|
23 |
+
class_balanced_loss = ClassBalancedLoss(
|
24 |
+
num_classes=3,
|
25 |
+
num_samples_each_class=[300, 433, 50],
|
26 |
+
reduction='mean',
|
27 |
+
)
|
28 |
+
loss = class_balanced_loss.forward(batch_loss=batch_loss, targets=targets)
|
29 |
+
print(loss)
|
30 |
+
return
|
31 |
+
|
32 |
+
@staticmethod
|
33 |
+
def demo2():
|
34 |
+
inputs: torch.FloatTensor = torch.randn(size=(2, 3), dtype=torch.float32)
|
35 |
+
targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)
|
36 |
+
|
37 |
+
focal_loss = FocalLoss(
|
38 |
+
num_classes=3,
|
39 |
+
# reduction='mean',
|
40 |
+
# reduction='sum',
|
41 |
+
reduction='none',
|
42 |
+
)
|
43 |
+
batch_loss = focal_loss.forward(inputs, targets)
|
44 |
+
print(batch_loss)
|
45 |
+
|
46 |
+
class_balanced_loss = ClassBalancedLoss(
|
47 |
+
num_classes=3,
|
48 |
+
num_samples_each_class=[300, 433, 50],
|
49 |
+
reduction='mean',
|
50 |
+
)
|
51 |
+
loss = class_balanced_loss.forward(batch_loss=batch_loss, targets=targets)
|
52 |
+
print(loss)
|
53 |
+
|
54 |
+
return
|
55 |
+
|
56 |
+
def __init__(self,
|
57 |
+
num_classes: int,
|
58 |
+
num_samples_each_class: List[int],
|
59 |
+
beta: float = 0.999,
|
60 |
+
reduction: str = 'mean') -> None:
|
61 |
+
super(ClassBalancedLoss, self).__init__(None, None, reduction)
|
62 |
+
|
63 |
+
effective_num = 1.0 - np.power(beta, num_samples_each_class)
|
64 |
+
weights = (1.0 - beta) / np.array(effective_num)
|
65 |
+
self.weights = weights / np.sum(weights) * num_classes
|
66 |
+
|
67 |
+
def forward(self, batch_loss: torch.FloatTensor, targets: torch.LongTensor):
|
68 |
+
"""
|
69 |
+
:param batch_loss: shape=[batch_size, 1]
|
70 |
+
:param targets: shape=[batch_size,]
|
71 |
+
:return:
|
72 |
+
"""
|
73 |
+
weights = list()
|
74 |
+
targets = targets.numpy()
|
75 |
+
for target in targets:
|
76 |
+
weights.append([self.weights[target]])
|
77 |
+
|
78 |
+
weights = torch.tensor(weights, dtype=torch.float32)
|
79 |
+
batch_loss = weights * batch_loss
|
80 |
+
|
81 |
+
if self.reduction == 'mean':
|
82 |
+
loss = batch_loss.mean()
|
83 |
+
elif self.reduction == 'sum':
|
84 |
+
loss = batch_loss.sum()
|
85 |
+
else:
|
86 |
+
loss = batch_loss
|
87 |
+
return loss
|
88 |
+
|
89 |
+
|
90 |
+
class EqualizationLoss(_Loss):
|
91 |
+
"""
|
92 |
+
在图像识别中的, sigmoid 的多标签分类, 且 num_classes 类别数之外有一个 background 背景类别.
|
93 |
+
Equalization Loss
|
94 |
+
https://arxiv.org/abs/2003.05176
|
95 |
+
Equalization Loss v2
|
96 |
+
https://arxiv.org/abs/2012.08548
|
97 |
+
"""
|
98 |
+
|
99 |
+
@staticmethod
|
100 |
+
def demo1():
|
101 |
+
logits: torch.FloatTensor = torch.randn(size=(3, 3), dtype=torch.float32)
|
102 |
+
targets: torch.LongTensor = torch.tensor([1, 2, 3], dtype=torch.long)
|
103 |
+
|
104 |
+
equalization_loss = EqualizationLoss(
|
105 |
+
num_samples_each_class=[300, 433, 50],
|
106 |
+
threshold=100,
|
107 |
+
reduction='mean',
|
108 |
+
)
|
109 |
+
loss = equalization_loss.forward(logits=logits, targets=targets)
|
110 |
+
print(loss)
|
111 |
+
return
|
112 |
+
|
113 |
+
def __init__(self,
|
114 |
+
num_samples_each_class: List[int],
|
115 |
+
threshold: int = 100,
|
116 |
+
reduction: str = 'mean') -> None:
|
117 |
+
super(EqualizationLoss, self).__init__(None, None, reduction)
|
118 |
+
self.num_samples_each_class = np.array(num_samples_each_class, dtype=np.int32)
|
119 |
+
self.threshold = threshold
|
120 |
+
|
121 |
+
def forward(self,
|
122 |
+
logits: torch.FloatTensor,
|
123 |
+
targets: torch.LongTensor
|
124 |
+
):
|
125 |
+
"""
|
126 |
+
num_classes + 1 对应于背景类别 background.
|
127 |
+
:param logits: shape=[batch_size, num_classes]
|
128 |
+
:param targets: shape=[batch_size]
|
129 |
+
:return:
|
130 |
+
"""
|
131 |
+
batch_size, num_classes = logits.size()
|
132 |
+
|
133 |
+
one_hot_targets = F.one_hot(targets, num_classes=num_classes + 1)
|
134 |
+
one_hot_targets = one_hot_targets[:, :-1]
|
135 |
+
|
136 |
+
exclude = self.exclude_func(
|
137 |
+
num_classes=num_classes,
|
138 |
+
targets=targets
|
139 |
+
)
|
140 |
+
is_tail = self.threshold_func(
|
141 |
+
num_classes=num_classes,
|
142 |
+
num_samples_each_class=self.num_samples_each_class,
|
143 |
+
threshold=self.threshold,
|
144 |
+
)
|
145 |
+
|
146 |
+
weights = 1 - exclude * is_tail * (1 - one_hot_targets)
|
147 |
+
|
148 |
+
batch_loss = F.binary_cross_entropy_with_logits(
|
149 |
+
logits,
|
150 |
+
one_hot_targets.float(),
|
151 |
+
reduction='none'
|
152 |
+
)
|
153 |
+
|
154 |
+
batch_loss = weights * batch_loss
|
155 |
+
|
156 |
+
if self.reduction == 'mean':
|
157 |
+
loss = batch_loss.mean()
|
158 |
+
elif self.reduction == 'sum':
|
159 |
+
loss = batch_loss.sum()
|
160 |
+
else:
|
161 |
+
loss = batch_loss
|
162 |
+
|
163 |
+
loss = loss / num_classes
|
164 |
+
return loss
|
165 |
+
|
166 |
+
@staticmethod
|
167 |
+
def exclude_func(num_classes: int, targets: torch.LongTensor):
|
168 |
+
"""
|
169 |
+
最后一个类别是背景 background.
|
170 |
+
:param num_classes: int,
|
171 |
+
:param targets: shape=[batch_size,]
|
172 |
+
:return: weight, shape=[batch_size, num_classes]
|
173 |
+
"""
|
174 |
+
batch_size = targets.shape[0]
|
175 |
+
weight = (targets != num_classes).float()
|
176 |
+
weight = weight.view(batch_size, 1).expand(batch_size, num_classes)
|
177 |
+
return weight
|
178 |
+
|
179 |
+
@staticmethod
|
180 |
+
def threshold_func(num_classes: int, num_samples_each_class: np.ndarray, threshold: int):
|
181 |
+
"""
|
182 |
+
:param num_classes: int,
|
183 |
+
:param num_samples_each_class: shape=[num_classes]
|
184 |
+
:param threshold: int,
|
185 |
+
:return: weight, shape=[1, num_classes]
|
186 |
+
"""
|
187 |
+
weight = torch.zeros(size=(num_classes,))
|
188 |
+
weight[num_samples_each_class < threshold] = 1
|
189 |
+
weight = torch.unsqueeze(weight, dim=0)
|
190 |
+
return weight
|
191 |
+
|
192 |
+
|
193 |
+
class FocalLoss(_Loss):
|
194 |
+
"""
|
195 |
+
https://arxiv.org/abs/1708.02002
|
196 |
+
"""
|
197 |
+
@staticmethod
|
198 |
+
def demo1(self):
|
199 |
+
inputs: torch.FloatTensor = torch.randn(size=(2, 3), dtype=torch.float32)
|
200 |
+
targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)
|
201 |
+
|
202 |
+
focal_loss = FocalLoss(
|
203 |
+
num_classes=3,
|
204 |
+
reduction='mean',
|
205 |
+
# reduction='sum',
|
206 |
+
# reduction='none',
|
207 |
+
)
|
208 |
+
loss = focal_loss.forward(inputs, targets)
|
209 |
+
print(loss)
|
210 |
+
return
|
211 |
+
|
212 |
+
def __init__(self,
|
213 |
+
num_classes: int,
|
214 |
+
alpha: List[float] = None,
|
215 |
+
gamma: int = 2,
|
216 |
+
reduction: str = 'mean',
|
217 |
+
inputs_logits: bool = True) -> None:
|
218 |
+
"""
|
219 |
+
:param num_classes:
|
220 |
+
:param alpha:
|
221 |
+
:param gamma:
|
222 |
+
:param reduction: (`none`, `mean`, `sum`) available.
|
223 |
+
:param inputs_logits: if False, the inputs should be probs.
|
224 |
+
"""
|
225 |
+
super(FocalLoss, self).__init__(None, None, reduction)
|
226 |
+
if alpha is None:
|
227 |
+
self.alpha = torch.ones(num_classes, 1)
|
228 |
+
else:
|
229 |
+
self.alpha = torch.tensor(alpha, dtype=torch.float32)
|
230 |
+
self.gamma = gamma
|
231 |
+
self.num_classes = num_classes
|
232 |
+
self.inputs_logits = inputs_logits
|
233 |
+
|
234 |
+
def forward(self,
|
235 |
+
inputs: torch.FloatTensor,
|
236 |
+
targets: torch.LongTensor):
|
237 |
+
"""
|
238 |
+
:param inputs: logits, shape=[batch_size, num_classes]
|
239 |
+
:param targets: shape=[batch_size,]
|
240 |
+
:return:
|
241 |
+
"""
|
242 |
+
batch_size, num_classes = inputs.shape
|
243 |
+
|
244 |
+
if self.inputs_logits:
|
245 |
+
probs = F.softmax(inputs, dim=-1)
|
246 |
+
else:
|
247 |
+
probs = inputs
|
248 |
+
|
249 |
+
# class_mask = inputs.data.new(batch_size, num_classes).fill_(0)
|
250 |
+
class_mask = torch.zeros(size=(batch_size, num_classes), dtype=inputs.dtype, device=inputs.device)
|
251 |
+
# class_mask = Variable(class_mask)
|
252 |
+
ids = targets.view(-1, 1)
|
253 |
+
class_mask.scatter_(1, ids.data, 1.)
|
254 |
+
|
255 |
+
if inputs.is_cuda and not self.alpha.is_cuda:
|
256 |
+
self.alpha = self.alpha.cuda()
|
257 |
+
alpha = self.alpha[ids.data.view(-1)]
|
258 |
+
|
259 |
+
probs = (probs * class_mask).sum(1).view(-1, 1)
|
260 |
+
|
261 |
+
log_p = probs.log()
|
262 |
+
|
263 |
+
batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p
|
264 |
+
|
265 |
+
if self.reduction == 'mean':
|
266 |
+
loss = batch_loss.mean()
|
267 |
+
elif self.reduction == 'sum':
|
268 |
+
loss = batch_loss.sum()
|
269 |
+
else:
|
270 |
+
loss = batch_loss
|
271 |
+
return loss
|
272 |
+
|
273 |
+
|
274 |
+
class HingeLoss(_Loss):
|
275 |
+
@staticmethod
|
276 |
+
def demo1():
|
277 |
+
inputs: torch.FloatTensor = torch.randn(size=(2, 3), dtype=torch.float32)
|
278 |
+
targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)
|
279 |
+
|
280 |
+
hinge_loss = HingeLoss(
|
281 |
+
margin_list=[300, 433, 50],
|
282 |
+
reduction='mean',
|
283 |
+
)
|
284 |
+
loss = hinge_loss.forward(inputs=inputs, targets=targets)
|
285 |
+
print(loss)
|
286 |
+
return
|
287 |
+
|
288 |
+
def __init__(self,
|
289 |
+
margin_list: List[float],
|
290 |
+
max_margin: float = 0.5,
|
291 |
+
scale: float = 1.0,
|
292 |
+
weight: Optional[torch.Tensor] = None,
|
293 |
+
reduction: str = 'mean') -> None:
|
294 |
+
super(HingeLoss, self).__init__(None, None, reduction)
|
295 |
+
|
296 |
+
self.max_margin = max_margin
|
297 |
+
self.scale = scale
|
298 |
+
self.weight = weight
|
299 |
+
|
300 |
+
margin_list = np.array(margin_list)
|
301 |
+
margin_list = margin_list * (max_margin / np.max(margin_list))
|
302 |
+
self.margin_list = torch.tensor(margin_list, dtype=torch.float32)
|
303 |
+
|
304 |
+
def forward(self,
|
305 |
+
inputs: torch.FloatTensor,
|
306 |
+
targets: torch.LongTensor
|
307 |
+
):
|
308 |
+
"""
|
309 |
+
:param inputs: logits, shape=[batch_size, num_classes]
|
310 |
+
:param targets: shape=[batch_size,]
|
311 |
+
:return:
|
312 |
+
"""
|
313 |
+
batch_size, num_classes = inputs.shape
|
314 |
+
one_hot_targets = F.one_hot(targets, num_classes=num_classes)
|
315 |
+
margin_list = torch.unsqueeze(self.margin_list, dim=0)
|
316 |
+
|
317 |
+
batch_margin = torch.sum(margin_list * one_hot_targets, dim=-1)
|
318 |
+
batch_margin = torch.unsqueeze(batch_margin, dim=-1)
|
319 |
+
inputs_margin = inputs - batch_margin
|
320 |
+
|
321 |
+
# 将类别对应的 logits 值减小一点, 以形成 margin 边界.
|
322 |
+
logits = torch.where(one_hot_targets > 0, inputs_margin, inputs)
|
323 |
+
|
324 |
+
loss = F.cross_entropy(
|
325 |
+
input=self.scale * logits,
|
326 |
+
target=targets,
|
327 |
+
weight=self.weight,
|
328 |
+
reduction=self.reduction,
|
329 |
+
)
|
330 |
+
return loss
|
331 |
+
|
332 |
+
|
333 |
+
class HingeLinear(nn.Module):
|
334 |
+
"""
|
335 |
+
use this instead of `HingeLoss`, then you can combine it with `FocalLoss` or others.
|
336 |
+
"""
|
337 |
+
def __init__(self,
|
338 |
+
margin_list: List[float],
|
339 |
+
max_margin: float = 0.5,
|
340 |
+
scale: float = 1.0,
|
341 |
+
weight: Optional[torch.Tensor] = None
|
342 |
+
) -> None:
|
343 |
+
super(HingeLinear, self).__init__()
|
344 |
+
|
345 |
+
self.max_margin = max_margin
|
346 |
+
self.scale = scale
|
347 |
+
self.weight = weight
|
348 |
+
|
349 |
+
margin_list = np.array(margin_list)
|
350 |
+
margin_list = margin_list * (max_margin / np.max(margin_list))
|
351 |
+
self.margin_list = torch.tensor(margin_list, dtype=torch.float32)
|
352 |
+
|
353 |
+
def forward(self,
|
354 |
+
inputs: torch.FloatTensor,
|
355 |
+
targets: torch.LongTensor
|
356 |
+
):
|
357 |
+
"""
|
358 |
+
:param inputs: logits, shape=[batch_size, num_classes]
|
359 |
+
:param targets: shape=[batch_size,]
|
360 |
+
:return:
|
361 |
+
"""
|
362 |
+
if self.training and targets is not None:
|
363 |
+
batch_size, num_classes = inputs.shape
|
364 |
+
one_hot_targets = F.one_hot(targets, num_classes=num_classes)
|
365 |
+
margin_list = torch.unsqueeze(self.margin_list, dim=0)
|
366 |
+
|
367 |
+
batch_margin = torch.sum(margin_list * one_hot_targets, dim=-1)
|
368 |
+
batch_margin = torch.unsqueeze(batch_margin, dim=-1)
|
369 |
+
inputs_margin = inputs - batch_margin
|
370 |
+
|
371 |
+
# 将类别对应的 logits 值减小一点, 以形成 margin 边界.
|
372 |
+
logits = torch.where(one_hot_targets > 0, inputs_margin, inputs)
|
373 |
+
logits = logits * self.scale
|
374 |
+
else:
|
375 |
+
logits = inputs
|
376 |
+
return logits
|
377 |
+
|
378 |
+
|
379 |
+
class LDAMLoss(_Loss):
|
380 |
+
"""
|
381 |
+
https://arxiv.org/abs/1906.07413
|
382 |
+
"""
|
383 |
+
@staticmethod
|
384 |
+
def demo1():
|
385 |
+
inputs: torch.FloatTensor = torch.randn(size=(2, 3), dtype=torch.float32)
|
386 |
+
targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)
|
387 |
+
|
388 |
+
ldam_loss = LDAMLoss(
|
389 |
+
num_samples_each_class=[300, 433, 50],
|
390 |
+
reduction='mean',
|
391 |
+
)
|
392 |
+
loss = ldam_loss.forward(inputs=inputs, targets=targets)
|
393 |
+
print(loss)
|
394 |
+
return
|
395 |
+
|
396 |
+
def __init__(self,
|
397 |
+
num_samples_each_class: List[int],
|
398 |
+
max_margin: float = 0.5,
|
399 |
+
scale: float = 30.0,
|
400 |
+
weight: Optional[torch.Tensor] = None,
|
401 |
+
reduction: str = 'mean') -> None:
|
402 |
+
super(LDAMLoss, self).__init__(None, None, reduction)
|
403 |
+
|
404 |
+
margin_list = np.power(num_samples_each_class, -0.25)
|
405 |
+
margin_list = margin_list * (max_margin / np.max(margin_list))
|
406 |
+
|
407 |
+
self.num_samples_each_class = num_samples_each_class
|
408 |
+
self.margin_list = torch.tensor(margin_list, dtype=torch.float32)
|
409 |
+
self.scale = scale
|
410 |
+
self.weight = weight
|
411 |
+
|
412 |
+
def forward(self,
|
413 |
+
inputs: torch.FloatTensor,
|
414 |
+
targets: torch.LongTensor
|
415 |
+
):
|
416 |
+
"""
|
417 |
+
:param inputs: logits, shape=[batch_size, num_classes]
|
418 |
+
:param targets: shape=[batch_size,]
|
419 |
+
:return:
|
420 |
+
"""
|
421 |
+
batch_size, num_classes = inputs.shape
|
422 |
+
one_hot_targets = F.one_hot(targets, num_classes=num_classes)
|
423 |
+
margin_list = torch.unsqueeze(self.margin_list, dim=0)
|
424 |
+
|
425 |
+
batch_margin = torch.sum(margin_list * one_hot_targets, dim=-1)
|
426 |
+
batch_margin = torch.unsqueeze(batch_margin, dim=-1)
|
427 |
+
inputs_margin = inputs - batch_margin
|
428 |
+
|
429 |
+
# 将类别对应的 logits 值减小一点, 以形成 margin 边界.
|
430 |
+
logits = torch.where(one_hot_targets > 0, inputs_margin, inputs)
|
431 |
+
|
432 |
+
loss = F.cross_entropy(
|
433 |
+
input=self.scale * logits,
|
434 |
+
target=targets,
|
435 |
+
weight=self.weight,
|
436 |
+
reduction=self.reduction,
|
437 |
+
)
|
438 |
+
return loss
|
439 |
+
|
440 |
+
|
441 |
+
class NegativeEntropy(_Loss):
|
442 |
+
def __init__(self,
|
443 |
+
reduction: str = 'mean',
|
444 |
+
inputs_logits: bool = True) -> None:
|
445 |
+
super(NegativeEntropy, self).__init__(None, None, reduction)
|
446 |
+
self.inputs_logits = inputs_logits
|
447 |
+
|
448 |
+
def forward(self,
|
449 |
+
inputs: torch.FloatTensor,
|
450 |
+
targets: torch.LongTensor):
|
451 |
+
if self.inputs_logits:
|
452 |
+
probs = F.softmax(inputs, dim=-1)
|
453 |
+
log_probs = torch.nn.functional.log_softmax(probs, dim=-1)
|
454 |
+
else:
|
455 |
+
probs = inputs
|
456 |
+
log_probs = torch.log(probs)
|
457 |
+
|
458 |
+
weighted_negative_likelihood = - log_probs * probs
|
459 |
+
|
460 |
+
loss = - weighted_negative_likelihood.sum()
|
461 |
+
return loss
|
462 |
+
|
463 |
+
|
464 |
+
class LargeMarginSoftMaxLoss(_Loss):
|
465 |
+
"""
|
466 |
+
Alias: L-Softmax
|
467 |
+
|
468 |
+
https://arxiv.org/abs/1612.02295
|
469 |
+
https://github.com/wy1iu/LargeMargin_Softmax_Loss
|
470 |
+
https://github.com/amirhfarzaneh/lsoftmax-pytorch/blob/master/lsoftmax.py
|
471 |
+
|
472 |
+
参考链接:
|
473 |
+
https://www.jianshu.com/p/06cc3f84aa85
|
474 |
+
|
475 |
+
论文认为, softmax 和 cross entropy 的组合, 没有明确鼓励对特征进行判别学习.
|
476 |
+
|
477 |
+
"""
|
478 |
+
def __init__(self,
|
479 |
+
reduction: str = 'mean') -> None:
|
480 |
+
super(LargeMarginSoftMaxLoss, self).__init__(None, None, reduction)
|
481 |
+
|
482 |
+
|
483 |
+
class AngularSoftMaxLoss(_Loss):
|
484 |
+
"""
|
485 |
+
Alias: A-Softmax
|
486 |
+
|
487 |
+
https://arxiv.org/abs/1704.08063
|
488 |
+
|
489 |
+
https://github.com/woshildh/a-softmax_pytorch/blob/master/a_softmax.py
|
490 |
+
|
491 |
+
参考链接:
|
492 |
+
https://www.jianshu.com/p/06cc3f84aa85
|
493 |
+
|
494 |
+
好像作者认为人脸是一个球面, 所以将向量转换到一个球面上是有帮助的.
|
495 |
+
"""
|
496 |
+
def __init__(self,
|
497 |
+
reduction: str = 'mean') -> None:
|
498 |
+
super(AngularSoftMaxLoss, self).__init__(None, None, reduction)
|
499 |
+
|
500 |
+
|
501 |
+
class AdditiveMarginSoftMax(_Loss):
|
502 |
+
"""
|
503 |
+
Alias: AM-Softmax
|
504 |
+
|
505 |
+
https://arxiv.org/abs/1801.05599
|
506 |
+
|
507 |
+
Large Margin Cosine Loss
|
508 |
+
https://arxiv.org/abs/1801.09414
|
509 |
+
|
510 |
+
参考链接:
|
511 |
+
https://www.jianshu.com/p/06cc3f84aa85
|
512 |
+
|
513 |
+
说明:
|
514 |
+
相对于普通的 对 logits 做 softmax,
|
515 |
+
它将真实标签对应的 logit 值减去 m, 来让模型它该值调整得更大一些.
|
516 |
+
另外, 它还将每个 logits 乘以 s, 这可以控制各 logits 之间的相对大小.
|
517 |
+
根 HingeLoss 有点像.
|
518 |
+
"""
|
519 |
+
def __init__(self,
|
520 |
+
reduction: str = 'mean') -> None:
|
521 |
+
super(AdditiveMarginSoftMax, self).__init__(None, None, reduction)
|
522 |
+
|
523 |
+
|
524 |
+
class AdditiveAngularMarginSoftMax(_Loss):
|
525 |
+
"""
|
526 |
+
Alias: ArcFace, AAM-Softmax
|
527 |
+
|
528 |
+
ArcFace: Additive Angular Margin Loss for Deep Face Recognition
|
529 |
+
https://arxiv.org/abs/1801.07698
|
530 |
+
|
531 |
+
参考代码:
|
532 |
+
https://github.com/huangkeju/AAMSoftmax-OpenMax/blob/main/AAMSoftmax%2BOvA/metrics.py
|
533 |
+
|
534 |
+
"""
|
535 |
+
@staticmethod
|
536 |
+
def demo1():
|
537 |
+
"""
|
538 |
+
角度与数值转换
|
539 |
+
pi / 180 代表 1 度,
|
540 |
+
pi / 180 = 0.01745
|
541 |
+
"""
|
542 |
+
|
543 |
+
# 度数转数值
|
544 |
+
degree = 10
|
545 |
+
result = degree * math.pi / 180
|
546 |
+
print(result)
|
547 |
+
|
548 |
+
# 数值转数度
|
549 |
+
radian = 0.2
|
550 |
+
result = radian / (math.pi / 180)
|
551 |
+
print(result)
|
552 |
+
|
553 |
+
return
|
554 |
+
|
555 |
+
def __init__(self,
|
556 |
+
hidden_size: int,
|
557 |
+
num_labels: int,
|
558 |
+
margin: float = 0.2,
|
559 |
+
scale: float = 10.0,
|
560 |
+
):
|
561 |
+
"""
|
562 |
+
:param hidden_size:
|
563 |
+
:param num_labels:
|
564 |
+
:param margin: 建议取值角度为 [10, 30], 对应的数值为 [0.1745, 0.5236]
|
565 |
+
:param scale:
|
566 |
+
"""
|
567 |
+
super(AdditiveAngularMarginSoftMax, self).__init__()
|
568 |
+
self.margin = margin
|
569 |
+
self.scale = scale
|
570 |
+
self.weight = torch.nn.Parameter(torch.FloatTensor(num_labels, hidden_size), requires_grad=True)
|
571 |
+
nn.init.xavier_uniform_(self.weight)
|
572 |
+
|
573 |
+
self.cos_margin = math.cos(self.margin)
|
574 |
+
self.sin_margin = math.sin(self.margin)
|
575 |
+
|
576 |
+
# sin(a-b) = sin(a)cos(b) - cos(a)sin(b)
|
577 |
+
# sin(pi - a) = sin(a)
|
578 |
+
|
579 |
+
self.loss = nn.CrossEntropyLoss()
|
580 |
+
|
581 |
+
def forward(self,
|
582 |
+
inputs: torch.Tensor,
|
583 |
+
label: torch.LongTensor = None
|
584 |
+
):
|
585 |
+
"""
|
586 |
+
:param inputs: shape=[batch_size, ..., hidden_size]
|
587 |
+
:param label:
|
588 |
+
:return: logits
|
589 |
+
"""
|
590 |
+
x = F.normalize(inputs)
|
591 |
+
weight = F.normalize(self.weight)
|
592 |
+
cosine = F.linear(x, weight)
|
593 |
+
|
594 |
+
if self.training:
|
595 |
+
|
596 |
+
# sin^2 + cos^2 = 1
|
597 |
+
sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))
|
598 |
+
|
599 |
+
# cos(a+b) = cos(a)cos(b) - sin(a)sin(b)
|
600 |
+
cosine_theta_margin = cosine * self.cos_margin - sine * self.sin_margin
|
601 |
+
|
602 |
+
# when the `cosine > - self.cos_margin` there is enough space to add margin on theta.
|
603 |
+
cosine_theta_margin = torch.where(cosine > - self.cos_margin, cosine_theta_margin, cosine - (self.margin * self.sin_margin))
|
604 |
+
|
605 |
+
one_hot = torch.zeros_like(cosine)
|
606 |
+
one_hot.scatter_(1, label.view(-1, 1), 1)
|
607 |
+
|
608 |
+
#
|
609 |
+
logits = torch.where(one_hot == 1, cosine_theta_margin, cosine)
|
610 |
+
logits = logits * self.scale
|
611 |
+
else:
|
612 |
+
logits = cosine
|
613 |
+
|
614 |
+
loss = self.loss(logits, label)
|
615 |
+
# prec1 = accuracy(output.detach(), label.detach(), topk=(1,))[0]
|
616 |
+
return loss
|
617 |
+
|
618 |
+
|
619 |
+
class AdditiveAngularMarginLinear(nn.Module):
|
620 |
+
"""
|
621 |
+
Alias: ArcFace, AAM-Softmax
|
622 |
+
|
623 |
+
ArcFace: Additive Angular Margin Loss for Deep Face Recognition
|
624 |
+
https://arxiv.org/abs/1801.07698
|
625 |
+
|
626 |
+
参考代码:
|
627 |
+
https://github.com/huangkeju/AAMSoftmax-OpenMax/blob/main/AAMSoftmax%2BOvA/metrics.py
|
628 |
+
|
629 |
+
"""
|
630 |
+
@staticmethod
|
631 |
+
def demo1():
|
632 |
+
"""
|
633 |
+
角度与数值转换
|
634 |
+
pi / 180 代表 1 度,
|
635 |
+
pi / 180 = 0.01745
|
636 |
+
"""
|
637 |
+
|
638 |
+
# 度数转数值
|
639 |
+
degree = 10
|
640 |
+
result = degree * math.pi / 180
|
641 |
+
print(result)
|
642 |
+
|
643 |
+
# 数值转数度
|
644 |
+
radian = 0.2
|
645 |
+
result = radian / (math.pi / 180)
|
646 |
+
print(result)
|
647 |
+
|
648 |
+
return
|
649 |
+
|
650 |
+
@staticmethod
|
651 |
+
def demo2():
|
652 |
+
|
653 |
+
return
|
654 |
+
|
655 |
+
def __init__(self,
|
656 |
+
hidden_size: int,
|
657 |
+
num_labels: int,
|
658 |
+
margin: float = 0.2,
|
659 |
+
scale: float = 10.0,
|
660 |
+
):
|
661 |
+
"""
|
662 |
+
:param hidden_size:
|
663 |
+
:param num_labels:
|
664 |
+
:param margin: 建议取值角度为 [10, 30], 对应的数值为 [0.1745, 0.5236]
|
665 |
+
:param scale:
|
666 |
+
"""
|
667 |
+
super(AdditiveAngularMarginLinear, self).__init__()
|
668 |
+
self.margin = margin
|
669 |
+
self.scale = scale
|
670 |
+
self.weight = torch.nn.Parameter(torch.FloatTensor(num_labels, hidden_size), requires_grad=True)
|
671 |
+
nn.init.xavier_uniform_(self.weight)
|
672 |
+
|
673 |
+
self.cos_margin = math.cos(self.margin)
|
674 |
+
self.sin_margin = math.sin(self.margin)
|
675 |
+
|
676 |
+
# sin(a-b) = sin(a)cos(b) - cos(a)sin(b)
|
677 |
+
# sin(pi - a) = sin(a)
|
678 |
+
|
679 |
+
def forward(self,
|
680 |
+
inputs: torch.Tensor,
|
681 |
+
targets: torch.LongTensor = None
|
682 |
+
):
|
683 |
+
"""
|
684 |
+
:param inputs: shape=[batch_size, ..., hidden_size]
|
685 |
+
:param targets:
|
686 |
+
:return: logits
|
687 |
+
"""
|
688 |
+
x = F.normalize(inputs)
|
689 |
+
weight = F.normalize(self.weight)
|
690 |
+
cosine = F.linear(x, weight)
|
691 |
+
|
692 |
+
if self.training and targets is not None:
|
693 |
+
# sin^2 + cos^2 = 1
|
694 |
+
sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))
|
695 |
+
|
696 |
+
# cos(a+b) = cos(a)cos(b) - sin(a)sin(b)
|
697 |
+
cosine_theta_margin = cosine * self.cos_margin - sine * self.sin_margin
|
698 |
+
|
699 |
+
# when the `cosine > - self.cos_margin` there is enough space to add margin on theta.
|
700 |
+
cosine_theta_margin = torch.where(cosine > - self.cos_margin, cosine_theta_margin, cosine - (self.margin * self.sin_margin))
|
701 |
+
|
702 |
+
one_hot = torch.zeros_like(cosine)
|
703 |
+
one_hot.scatter_(1, targets.view(-1, 1), 1)
|
704 |
+
|
705 |
+
logits = torch.where(one_hot == 1, cosine_theta_margin, cosine)
|
706 |
+
logits = logits * self.scale
|
707 |
+
else:
|
708 |
+
logits = cosine
|
709 |
+
return logits
|
710 |
+
|
711 |
+
|
712 |
+
def demo1():
|
713 |
+
HingeLoss.demo1()
|
714 |
+
return
|
715 |
+
|
716 |
+
|
717 |
+
def demo2():
|
718 |
+
AdditiveAngularMarginSoftMax.demo1()
|
719 |
+
|
720 |
+
inputs = torch.ones(size=(2, 5), dtype=torch.float32)
|
721 |
+
label: torch.LongTensor = torch.tensor(data=[0, 1], dtype=torch.long)
|
722 |
+
|
723 |
+
aam_softmax = AdditiveAngularMarginSoftMax(
|
724 |
+
hidden_size=5,
|
725 |
+
num_labels=2,
|
726 |
+
margin=1,
|
727 |
+
scale=1
|
728 |
+
)
|
729 |
+
|
730 |
+
outputs = aam_softmax.forward(inputs, label)
|
731 |
+
print(outputs)
|
732 |
+
|
733 |
+
return
|
734 |
+
|
735 |
+
|
736 |
+
if __name__ == '__main__':
|
737 |
+
# demo1()
|
738 |
+
demo2()
|
toolbox/torch/training/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == '__main__':
|
6 |
+
pass
|
toolbox/torch/training/metrics/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == '__main__':
|
6 |
+
pass
|
toolbox/torch/training/metrics/categorical_accuracy.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
from overrides import overrides
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
class CategoricalAccuracy(object):
|
8 |
+
def __init__(self, top_k: int = 1, tie_break: bool = False) -> None:
|
9 |
+
if top_k > 1 and tie_break:
|
10 |
+
raise AssertionError("Tie break in Categorical Accuracy "
|
11 |
+
"can be done only for maximum (top_k = 1)")
|
12 |
+
if top_k <= 0:
|
13 |
+
raise AssertionError("top_k passed to Categorical Accuracy must be > 0")
|
14 |
+
self._top_k = top_k
|
15 |
+
self._tie_break = tie_break
|
16 |
+
self.correct_count = 0.
|
17 |
+
self.total_count = 0.
|
18 |
+
|
19 |
+
def __call__(self,
|
20 |
+
predictions: torch.Tensor,
|
21 |
+
gold_labels: torch.Tensor,
|
22 |
+
mask: Optional[torch.Tensor] = None):
|
23 |
+
|
24 |
+
# predictions, gold_labels, mask = self.unwrap_to_tensors(predictions, gold_labels, mask)
|
25 |
+
|
26 |
+
# Some sanity checks.
|
27 |
+
num_classes = predictions.size(-1)
|
28 |
+
if gold_labels.dim() != predictions.dim() - 1:
|
29 |
+
raise AssertionError("gold_labels must have dimension == predictions.size() - 1 but "
|
30 |
+
"found tensor of shape: {}".format(predictions.size()))
|
31 |
+
if (gold_labels >= num_classes).any():
|
32 |
+
raise AssertionError("A gold label passed to Categorical Accuracy contains an id >= {}, "
|
33 |
+
"the number of classes.".format(num_classes))
|
34 |
+
|
35 |
+
predictions = predictions.view((-1, num_classes))
|
36 |
+
gold_labels = gold_labels.view(-1).long()
|
37 |
+
if not self._tie_break:
|
38 |
+
# Top K indexes of the predictions (or fewer, if there aren't K of them).
|
39 |
+
# Special case topk == 1, because it's common and .max() is much faster than .topk().
|
40 |
+
if self._top_k == 1:
|
41 |
+
top_k = predictions.max(-1)[1].unsqueeze(-1)
|
42 |
+
else:
|
43 |
+
top_k = predictions.topk(min(self._top_k, predictions.shape[-1]), -1)[1]
|
44 |
+
|
45 |
+
# This is of shape (batch_size, ..., top_k).
|
46 |
+
correct = top_k.eq(gold_labels.unsqueeze(-1)).float()
|
47 |
+
else:
|
48 |
+
# prediction is correct if gold label falls on any of the max scores. distribute score by tie_counts
|
49 |
+
max_predictions = predictions.max(-1)[0]
|
50 |
+
max_predictions_mask = predictions.eq(max_predictions.unsqueeze(-1))
|
51 |
+
# max_predictions_mask is (rows X num_classes) and gold_labels is (batch_size)
|
52 |
+
# ith entry in gold_labels points to index (0-num_classes) for ith row in max_predictions
|
53 |
+
# For each row check if index pointed by gold_label is was 1 or not (among max scored classes)
|
54 |
+
correct = max_predictions_mask[torch.arange(gold_labels.numel()).long(), gold_labels].float()
|
55 |
+
tie_counts = max_predictions_mask.sum(-1)
|
56 |
+
correct /= tie_counts.float()
|
57 |
+
correct.unsqueeze_(-1)
|
58 |
+
|
59 |
+
if mask is not None:
|
60 |
+
correct *= mask.view(-1, 1).float()
|
61 |
+
self.total_count += mask.sum()
|
62 |
+
else:
|
63 |
+
self.total_count += gold_labels.numel()
|
64 |
+
self.correct_count += correct.sum()
|
65 |
+
|
66 |
+
def get_metric(self, reset: bool = False):
|
67 |
+
"""
|
68 |
+
Returns
|
69 |
+
-------
|
70 |
+
The accumulated accuracy.
|
71 |
+
"""
|
72 |
+
if self.total_count > 1e-12:
|
73 |
+
accuracy = float(self.correct_count) / float(self.total_count)
|
74 |
+
else:
|
75 |
+
accuracy = 0.0
|
76 |
+
if reset:
|
77 |
+
self.reset()
|
78 |
+
return {'accuracy': accuracy}
|
79 |
+
|
80 |
+
def reset(self):
|
81 |
+
self.correct_count = 0.0
|
82 |
+
self.total_count = 0.0
|
toolbox/torch/training/metrics/verbose_categorical_accuracy.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
from typing import Dict, List, Optional
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
class CategoricalAccuracyVerbose(object):
|
10 |
+
def __init__(self,
|
11 |
+
index_to_token: Dict[int, str],
|
12 |
+
label_namespace: str = "labels",
|
13 |
+
top_k: int = 1,
|
14 |
+
) -> None:
|
15 |
+
if top_k <= 0:
|
16 |
+
raise AssertionError("top_k passed to Categorical Accuracy must be > 0")
|
17 |
+
self._index_to_token = index_to_token
|
18 |
+
self._label_namespace = label_namespace
|
19 |
+
self._top_k = top_k
|
20 |
+
self.correct_count = 0.
|
21 |
+
self.total_count = 0.
|
22 |
+
self.label_correct_count = dict()
|
23 |
+
self.label_total_count = dict()
|
24 |
+
|
25 |
+
def __call__(self,
|
26 |
+
predictions: torch.Tensor,
|
27 |
+
gold_labels: torch.Tensor,
|
28 |
+
mask: Optional[torch.Tensor] = None):
|
29 |
+
num_classes = predictions.size(-1)
|
30 |
+
if gold_labels.dim() != predictions.dim() - 1:
|
31 |
+
raise AssertionError("gold_labels must have dimension == predictions.size() - 1 but "
|
32 |
+
"found tensor of shape: {}".format(predictions.size()))
|
33 |
+
if (gold_labels >= num_classes).any():
|
34 |
+
raise AssertionError("A gold label passed to Categorical Accuracy contains an id >= {}, "
|
35 |
+
"the number of classes.".format(num_classes))
|
36 |
+
|
37 |
+
predictions = predictions.view((-1, num_classes))
|
38 |
+
gold_labels = gold_labels.view(-1).long()
|
39 |
+
|
40 |
+
# Top K indexes of the predictions (or fewer, if there aren't K of them).
|
41 |
+
# Special case topk == 1, because it's common and .max() is much faster than .topk().
|
42 |
+
if self._top_k == 1:
|
43 |
+
top_k = predictions.max(-1)[1].unsqueeze(-1)
|
44 |
+
else:
|
45 |
+
top_k = predictions.topk(min(self._top_k, predictions.shape[-1]), -1)[1]
|
46 |
+
|
47 |
+
# This is of shape (batch_size, ..., top_k).
|
48 |
+
correct = top_k.eq(gold_labels.unsqueeze(-1)).float()
|
49 |
+
|
50 |
+
if mask is not None:
|
51 |
+
correct *= mask.view(-1, 1).float()
|
52 |
+
self.total_count += mask.sum()
|
53 |
+
else:
|
54 |
+
self.total_count += gold_labels.numel()
|
55 |
+
self.correct_count += correct.sum()
|
56 |
+
|
57 |
+
labels: List[int] = np.unique(gold_labels.cpu().numpy()).tolist()
|
58 |
+
for label in labels:
|
59 |
+
label_mask = (gold_labels == label)
|
60 |
+
|
61 |
+
label_correct = correct * label_mask.view(-1, 1).float()
|
62 |
+
label_correct = int(label_correct.sum())
|
63 |
+
label_count = int(label_mask.sum())
|
64 |
+
|
65 |
+
label_str = self._index_to_token[label]
|
66 |
+
if label_str in self.label_correct_count:
|
67 |
+
self.label_correct_count[label_str] += label_correct
|
68 |
+
else:
|
69 |
+
self.label_correct_count[label_str] = label_correct
|
70 |
+
|
71 |
+
if label_str in self.label_total_count:
|
72 |
+
self.label_total_count[label_str] += label_count
|
73 |
+
else:
|
74 |
+
self.label_total_count[label_str] = label_count
|
75 |
+
|
76 |
+
def get_metric(self, reset: bool = False):
|
77 |
+
"""
|
78 |
+
Returns
|
79 |
+
-------
|
80 |
+
The accumulated accuracy.
|
81 |
+
"""
|
82 |
+
result = dict()
|
83 |
+
if self.total_count > 1e-12:
|
84 |
+
accuracy = float(self.correct_count) / float(self.total_count)
|
85 |
+
else:
|
86 |
+
accuracy = 0.0
|
87 |
+
result['accuracy'] = accuracy
|
88 |
+
|
89 |
+
for label in self.label_total_count.keys():
|
90 |
+
total = self.label_total_count[label]
|
91 |
+
correct = self.label_correct_count.get(label, 0.0)
|
92 |
+
label_accuracy = correct / total
|
93 |
+
result[label] = label_accuracy
|
94 |
+
|
95 |
+
if reset:
|
96 |
+
self.reset()
|
97 |
+
return result
|
98 |
+
|
99 |
+
def reset(self):
|
100 |
+
self.correct_count = 0.0
|
101 |
+
self.total_count = 0.0
|
102 |
+
self.label_correct_count = dict()
|
103 |
+
self.label_total_count = dict()
|
104 |
+
|
105 |
+
|
106 |
+
def demo1():
|
107 |
+
|
108 |
+
categorical_accuracy_verbose = CategoricalAccuracyVerbose(
|
109 |
+
index_to_token={0: '0', 1: '1'},
|
110 |
+
top_k=2,
|
111 |
+
)
|
112 |
+
|
113 |
+
predictions = torch.randn(size=(2, 3), dtype=torch.float32)
|
114 |
+
gold_labels = torch.ones(size=(2,), dtype=torch.long)
|
115 |
+
# print(predictions)
|
116 |
+
# print(gold_labels)
|
117 |
+
|
118 |
+
categorical_accuracy_verbose(
|
119 |
+
predictions=predictions,
|
120 |
+
gold_labels=gold_labels,
|
121 |
+
)
|
122 |
+
metric = categorical_accuracy_verbose.get_metric()
|
123 |
+
print(metric)
|
124 |
+
return
|
125 |
+
|
126 |
+
|
127 |
+
if __name__ == '__main__':
|
128 |
+
demo1()
|
toolbox/torch/training/trainer/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
if __name__ == '__main__':
|
5 |
+
pass
|
toolbox/torch/training/trainer/trainer.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
if __name__ == '__main__':
|
5 |
+
pass
|
toolbox/torch/utils/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
if __name__ == '__main__':
|
5 |
+
pass
|
toolbox/torchaudio/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
if __name__ == '__main__':
|
5 |
+
pass
|
toolbox/torchaudio/configuration_utils.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import copy
|
4 |
+
import os
|
5 |
+
from typing import Any, Dict, Union
|
6 |
+
|
7 |
+
import yaml
|
8 |
+
|
9 |
+
|
10 |
+
CONFIG_FILE = "config.yaml"
|
11 |
+
|
12 |
+
|
13 |
+
class PretrainedConfig(object):
|
14 |
+
def __init__(self, **kwargs):
|
15 |
+
pass
|
16 |
+
|
17 |
+
@classmethod
|
18 |
+
def _dict_from_yaml_file(cls, yaml_file: Union[str, os.PathLike]):
|
19 |
+
with open(yaml_file, encoding="utf-8") as f:
|
20 |
+
config_dict = yaml.safe_load(f)
|
21 |
+
return config_dict
|
22 |
+
|
23 |
+
@classmethod
|
24 |
+
def get_config_dict(
|
25 |
+
cls, pretrained_model_name_or_path: Union[str, os.PathLike]
|
26 |
+
) -> Dict[str, Any]:
|
27 |
+
if os.path.isdir(pretrained_model_name_or_path):
|
28 |
+
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_FILE)
|
29 |
+
else:
|
30 |
+
config_file = pretrained_model_name_or_path
|
31 |
+
config_dict = cls._dict_from_yaml_file(config_file)
|
32 |
+
return config_dict
|
33 |
+
|
34 |
+
@classmethod
|
35 |
+
def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
|
36 |
+
for k, v in kwargs.items():
|
37 |
+
if k in config_dict.keys():
|
38 |
+
config_dict[k] = v
|
39 |
+
config = cls(**config_dict)
|
40 |
+
return config
|
41 |
+
|
42 |
+
@classmethod
|
43 |
+
def from_pretrained(
|
44 |
+
cls,
|
45 |
+
pretrained_model_name_or_path: Union[str, os.PathLike],
|
46 |
+
**kwargs,
|
47 |
+
):
|
48 |
+
config_dict = cls.get_config_dict(pretrained_model_name_or_path)
|
49 |
+
return cls.from_dict(config_dict, **kwargs)
|
50 |
+
|
51 |
+
def to_dict(self):
|
52 |
+
output = copy.deepcopy(self.__dict__)
|
53 |
+
return output
|
54 |
+
|
55 |
+
def to_yaml_file(self, yaml_file_path: Union[str, os.PathLike]):
|
56 |
+
config_dict = self.to_dict()
|
57 |
+
|
58 |
+
with open(yaml_file_path, "w", encoding="utf-8") as writer:
|
59 |
+
yaml.safe_dump(config_dict, writer)
|
60 |
+
|
61 |
+
|
62 |
+
if __name__ == '__main__':
|
63 |
+
pass
|