HoneyTian committed
Commit 69ad385 (0 parents)
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50):
  1. .gitattributes +35 -0
  2. .gitignore +18 -0
  3. Dockerfile +21 -0
  4. README.md +11 -0
  5. examples/vm_sound_classification/conv2d_classifier.yaml +45 -0
  6. examples/vm_sound_classification/requirements.txt +10 -0
  7. examples/vm_sound_classification/run.sh +188 -0
  8. examples/vm_sound_classification/step_1_prepare_data.py +150 -0
  9. examples/vm_sound_classification/step_2_make_vocabulary.py +51 -0
  10. examples/vm_sound_classification/step_3_train_model.py +331 -0
  11. examples/vm_sound_classification/step_4_evaluation_model.py +128 -0
  12. examples/vm_sound_classification/step_5_export_models.py +106 -0
  13. examples/vm_sound_classification/step_6_infer.py +91 -0
  14. examples/vm_sound_classification/step_7_test_model.py +93 -0
  15. examples/vm_sound_classification/stop.sh +3 -0
  16. examples/vm_sound_classification8/requirements.txt +9 -0
  17. examples/vm_sound_classification8/run.sh +157 -0
  18. examples/vm_sound_classification8/step_1_prepare_data.py +156 -0
  19. examples/vm_sound_classification8/step_2_make_vocabulary.py +69 -0
  20. examples/vm_sound_classification8/step_3_train_global_model.py +328 -0
  21. examples/vm_sound_classification8/step_4_train_country_model.py +349 -0
  22. examples/vm_sound_classification8/step_5_train_union.py +499 -0
  23. examples/vm_sound_classification8/stop.sh +3 -0
  24. install.sh +64 -0
  25. main.py +172 -0
  26. project_settings.py +19 -0
  27. requirements.txt +12 -0
  28. script/install_nvidia_driver.sh +184 -0
  29. script/install_python.sh +129 -0
  30. toolbox/__init__.py +5 -0
  31. toolbox/json/__init__.py +6 -0
  32. toolbox/json/misc.py +63 -0
  33. toolbox/os/__init__.py +6 -0
  34. toolbox/os/command.py +59 -0
  35. toolbox/os/environment.py +114 -0
  36. toolbox/os/other.py +9 -0
  37. toolbox/torch/__init__.py +5 -0
  38. toolbox/torch/modules/__init__.py +6 -0
  39. toolbox/torch/modules/gaussian_mixture.py +173 -0
  40. toolbox/torch/modules/highway.py +30 -0
  41. toolbox/torch/modules/loss.py +738 -0
  42. toolbox/torch/training/__init__.py +6 -0
  43. toolbox/torch/training/metrics/__init__.py +6 -0
  44. toolbox/torch/training/metrics/categorical_accuracy.py +82 -0
  45. toolbox/torch/training/metrics/verbose_categorical_accuracy.py +128 -0
  46. toolbox/torch/training/trainer/__init__.py +5 -0
  47. toolbox/torch/training/trainer/trainer.py +5 -0
  48. toolbox/torch/utils/__init__.py +5 -0
  49. toolbox/torchaudio/__init__.py +5 -0
  50. toolbox/torchaudio/configuration_utils.py +63 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,18 @@
+
+ .git/
+ .idea/
+
+ **/file_dir
+ **/flagged/
+ **/log/
+ **/logs/
+ **/__pycache__/
+
+ data/
+ docs/
+ dotenv/
+ trained_models/
+ temp/
+
+ #**/*.wav
+ **/*.xlsx
Dockerfile ADDED
@@ -0,0 +1,21 @@
+ FROM python:3.8
+
+ WORKDIR /code
+
+ COPY . /code
+
+ RUN pip install --upgrade pip
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+ RUN useradd -m -u 1000 user
+
+ USER user
+
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ WORKDIR $HOME/app
+
+ COPY --chown=user . $HOME/app
+
+ CMD ["python3", "main.py"]
README.md ADDED
@@ -0,0 +1,11 @@
+ ---
+ title: VM Sound Classification
+ emoji: 🐢
+ colorFrom: purple
+ colorTo: blue
+ sdk: docker
+ pinned: false
+ license: apache-2.0
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
examples/vm_sound_classification/conv2d_classifier.yaml ADDED
@@ -0,0 +1,45 @@
+ model_name: "cnn_audio_classifier"
+
+ mel_spectrogram_param:
+   sample_rate: 8000
+   n_fft: 512
+   win_length: 200
+   hop_length: 80
+   f_min: 10
+   f_max: 3800
+   window_fn: hamming
+   n_mels: 80
+
+ conv2d_block_param_list:
+   - batch_norm: true
+     in_channels: 1
+     out_channels: 16
+     kernel_size: 3
+     stride: 1
+     dilation: 3
+     activation: relu
+     dropout: 0.1
+   - in_channels: 16
+     out_channels: 16
+     kernel_size: 5
+     stride: 2
+     dilation: 3
+     activation: relu
+     dropout: 0.1
+   - in_channels: 16
+     out_channels: 16
+     kernel_size: 3
+     stride: 1
+     dilation: 2
+     activation: relu
+     dropout: 0.1
+
+ cls_head_param:
+   input_dim: 432
+   num_layers: 2
+   hidden_dims:
+     - 128
+     - 32
+   activations: relu
+   dropout: 0.1
+   num_labels: 8
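Note: the mel_spectrogram_param block above corresponds directly to the arguments of torchaudio.transforms.MelSpectrogram. A minimal sketch of that mapping (the resolution of the "window_fn: hamming" string to a torch window function is an assumption; the actual loading code lives in the toolbox.torchaudio package, which appears later in this commit):

import torch
import torchaudio

# Assumed mapping from the YAML string to a torch window function.
window_fn_map = {"hamming": torch.hamming_window, "hann": torch.hann_window}

mel_spectrogram_param = {
    "sample_rate": 8000, "n_fft": 512, "win_length": 200, "hop_length": 80,
    "f_min": 10, "f_max": 3800, "window_fn": "hamming", "n_mels": 80,
}

mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=mel_spectrogram_param["sample_rate"],
    n_fft=mel_spectrogram_param["n_fft"],
    win_length=mel_spectrogram_param["win_length"],
    hop_length=mel_spectrogram_param["hop_length"],
    f_min=mel_spectrogram_param["f_min"],
    f_max=mel_spectrogram_param["f_max"],
    window_fn=window_fn_map[mel_spectrogram_param["window_fn"]],
    n_mels=mel_spectrogram_param["n_mels"],
)

# One second of 8 kHz audio -> (channel, n_mels, time) mel spectrogram.
waveform = torch.randn(1, 8000)
mel = mel_transform(waveform)
print(mel.shape)  # torch.Size([1, 80, 101])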
examples/vm_sound_classification/requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch==1.13.1
+ torchaudio==0.13.1
+ fsspec==2022.1.0
+ librosa==0.9.2
+ pandas==1.1.5
+ openpyxl==3.0.9
+ xlrd==1.2.0
+ tqdm==4.64.1
+ overrides==1.9.0
+ pyyaml==6.0.1
examples/vm_sound_classification/run.sh ADDED
@@ -0,0 +1,188 @@
+ #!/usr/bin/env bash
+
+ : <<'END'
+
+ sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification3 \
+ --filename_patterns "E:/programmer/asr_datasets/voicemail/wav_finished/en-US/wav_finished/*/*.wav \
+ E:/programmer/asr_datasets/voicemail/wav_finished/id-ID/wav_finished/*/*.wav" \
+
+ sh run.sh --stage 0 --stop_stage 1 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification3 \
+ --filename_patterns "E:/programmer/asr_datasets/voicemail/wav_finished/en-US/wav_finished/*/*.wav \
+ E:/programmer/asr_datasets/voicemail/wav_finished/id-ID/wav_finished/*/*.wav" \
+
+ sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification3
+ sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification3
+
+ sh run.sh --stage 3 --stop_stage 5 --system_version centos --file_folder_name file_dir --final_model_name vm_sound_classification8-ch16 \
+ --filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
+
+
+ "
+
+ END
+
+
+ # params
+ system_version="windows";
+ verbose=true;
+ stage=0 # start from 0 if you need to start from data preparation
+ stop_stage=9
+
+ work_dir="$(pwd)"
+ file_folder_name=file_folder_name
+ final_model_name=final_model_name
+ filename_patterns="/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
+ nohup_name=nohup.out
+
+ country=en-US
+
+ # model params
+ batch_size=64
+ max_epochs=200
+ save_top_k=10
+ patience=5
+
+
+ # parse options
+ while true; do
+   [ -z "${1:-}" ] && break;  # break if there are no arguments
+   case "$1" in
+     --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
+       eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+       old_value="(eval echo \\$$name)";
+       if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
+         was_bool=true;
+       else
+         was_bool=false;
+       fi
+
+       # Set the variable to the right value-- the escaped quotes make it work if
+       # the option had spaces, like --cmd "queue.pl -sync y"
+       eval "${name}=\"$2\"";
+
+       # Check that Boolean-valued arguments are really Boolean.
+       if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+         echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+         exit 1;
+       fi
+       shift 2;
+       ;;
+
+     *) break;
+   esac
+ done
+
+ file_dir="${work_dir}/${file_folder_name}"
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
+
+ dataset="${file_dir}/dataset.xlsx"
+ train_dataset="${file_dir}/train.xlsx"
+ valid_dataset="${file_dir}/valid.xlsx"
+ evaluation_file="${file_dir}/evaluation.xlsx"
+ vocabulary_dir="${file_dir}/vocabulary"
+
+ $verbose && echo "system_version: ${system_version}"
+ $verbose && echo "file_folder_name: ${file_folder_name}"
+
+ if [ $system_version == "windows" ]; then
+   alias python3='D:/Users/tianx/PycharmProjects/virtualenv/vm_sound_classification/Scripts/python.exe'
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
+   #source /data/local/bin/vm_sound_classification/bin/activate
+   alias python3='/data/local/bin/vm_sound_classification/bin/python3'
+ fi
+
+
+ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+   $verbose && echo "stage 0: prepare data"
+   cd "${work_dir}" || exit 1
+   python3 step_1_prepare_data.py \
+     --file_dir "${file_dir}" \
+     --filename_patterns "${filename_patterns}" \
+     --train_dataset "${train_dataset}" \
+     --valid_dataset "${valid_dataset}" \
+
+ fi
+
+
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+   $verbose && echo "stage 1: make vocabulary"
+   cd "${work_dir}" || exit 1
+   python3 step_2_make_vocabulary.py \
+     --vocabulary_dir "${vocabulary_dir}" \
+     --train_dataset "${train_dataset}" \
+     --valid_dataset "${valid_dataset}" \
+
+ fi
+
+
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+   $verbose && echo "stage 2: train model"
+   cd "${work_dir}" || exit 1
+   python3 step_3_train_model.py \
+     --vocabulary_dir "${vocabulary_dir}" \
+     --train_dataset "${train_dataset}" \
+     --valid_dataset "${valid_dataset}" \
+     --serialization_dir "${file_dir}" \
+
+ fi
+
+
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+   $verbose && echo "stage 3: test model"
+   cd "${work_dir}" || exit 1
+   python3 step_4_evaluation_model.py \
+     --dataset "${dataset}" \
+     --vocabulary_dir "${vocabulary_dir}" \
+     --model_dir "${file_dir}/best" \
+     --output_file "${evaluation_file}" \
+
+ fi
+
+
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+   $verbose && echo "stage 4: export model"
+   cd "${work_dir}" || exit 1
+   python3 step_5_export_models.py \
+     --vocabulary_dir "${vocabulary_dir}" \
+     --model_dir "${file_dir}/best" \
+     --serialization_dir "${file_dir}" \
+
+ fi
+
+
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+   $verbose && echo "stage 5: collect files"
+   cd "${work_dir}" || exit 1
+
+   mkdir -p ${final_model_dir}
+
+   cp "${file_dir}/best"/* "${final_model_dir}"
+   cp -r "${file_dir}/vocabulary" "${final_model_dir}"
+
+   cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx"
+
+   cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip"
+   cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip"
+   cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip"
+   cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip"
+
+   cd "${final_model_dir}/.." || exit 1;
+
+   if [ -e "${final_model_name}.zip" ]; then
+     rm -rf "${final_model_name}_backup.zip"
+     mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
+   fi
+
+   zip -r "${final_model_name}.zip" "${final_model_name}"
+   rm -rf "${final_model_name}"
+
+ fi
+
+
+ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+   $verbose && echo "stage 6: clear file_dir"
+   cd "${work_dir}" || exit 1
+
+   rm -rf "${file_dir}";
+
+ fi
examples/vm_sound_classification/step_1_prepare_data.py ADDED
@@ -0,0 +1,150 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ from glob import glob
+ import json
+ import os
+ from pathlib import Path
+ import random
+ import sys
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import pandas as pd
+ from scipy.io import wavfile
+ from tqdm import tqdm
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--file_dir", default="./", type=str)
+     parser.add_argument("--task", default="default", type=str)
+     parser.add_argument("--filename_patterns", type=str)
+
+     parser.add_argument("--train_dataset", default="train.xlsx", type=str)
+     parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
+
+     args = parser.parse_args()
+     return args
+
+
+ def get_dataset(args):
+     filename_patterns = args.filename_patterns
+     filename_patterns = filename_patterns.split(" ")
+     print(filename_patterns)
+
+     file_dir = Path(args.file_dir)
+     file_dir.mkdir(exist_ok=True)
+
+     # label3_map = {
+     #     "bell": "voicemail",
+     #     "white_noise": "mute",
+     #     "low_white_noise": "mute",
+     #     "high_white_noise": "mute",
+     #     # "music": "music",
+     #     "mute": "mute",
+     #     "noise": "voice_or_noise",
+     #     "noise_mute": "voice_or_noise",
+     #     "voice": "voice_or_noise",
+     #     "voicemail": "voicemail",
+     # }
+     label8_map = {
+         "bell": "bell",
+         "white_noise": "white_noise",
+         "low_white_noise": "white_noise",
+         "high_white_noise": "white_noise",
+         "music": "music",
+         "mute": "mute",
+         "noise": "noise",
+         "noise_mute": "noise_mute",
+         "voice": "voice",
+         "voicemail": "voicemail",
+     }
+
+     result = list()
+     for filename_pattern in filename_patterns:
+         filename_list = glob(filename_pattern)
+         for filename in tqdm(filename_list):
+             filename = Path(filename)
+             sample_rate, signal = wavfile.read(filename.as_posix())
+             if len(signal) < sample_rate * 2:
+                 continue
+
+             folder = filename.parts[-2]
+             country = filename.parts[-4]
+
+             if folder not in label8_map.keys():
+                 continue
+
+             labels = label8_map[folder]
+
+             random1 = random.random()
+             random2 = random.random()
+
+             result.append({
+                 "filename": filename,
+                 "folder": folder,
+                 "category": country,
+                 "labels": labels,
+                 "random1": random1,
+                 "random2": random2,
+                 "flag": "TRAIN" if random2 < 0.8 else "TEST",
+             })
+
+     df = pd.DataFrame(result)
+     pivot_table = pd.pivot_table(df, index=["labels"], values=["filename"], aggfunc="count")
+     print(pivot_table)
+
+     df = df.sort_values(by=["random1"], ascending=False)
+     df.to_excel(
+         file_dir / "dataset.xlsx",
+         index=False,
+         # encoding="utf_8_sig"
+     )
+
+     return
+
+
+ def split_dataset(args):
+     """Split the data into train and test sets."""
+     file_dir = Path(args.file_dir)
+     file_dir.mkdir(exist_ok=True)
+
+     df = pd.read_excel(file_dir / "dataset.xlsx")
+
+     train = list()
+     test = list()
+
+     for i, row in df.iterrows():
+         flag = row["flag"]
+         if flag == "TRAIN":
+             train.append(row)
+         else:
+             test.append(row)
+
+     train = pd.DataFrame(train)
+     train.to_excel(
+         args.train_dataset,
+         index=False,
+         # encoding="utf_8_sig"
+     )
+     test = pd.DataFrame(test)
+     test.to_excel(
+         args.valid_dataset,
+         index=False,
+         # encoding="utf_8_sig"
+     )
+
+     return
+
+
+ def main():
+     args = get_args()
+     get_dataset(args)
+     split_dataset(args)
+     return
+
+
+ if __name__ == "__main__":
+     main()
examples/vm_sound_classification/step_2_make_vocabulary.py ADDED
@@ -0,0 +1,51 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ import os
+ from pathlib import Path
+ import sys
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import pandas as pd
+
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
+
+     parser.add_argument("--train_dataset", default="train.xlsx", type=str)
+     parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
+
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = get_args()
+
+     train_dataset = pd.read_excel(args.train_dataset)
+     valid_dataset = pd.read_excel(args.valid_dataset)
+
+     vocabulary = Vocabulary()
+
+     # train
+     for i, row in train_dataset.iterrows():
+         label = row["labels"]
+         vocabulary.add_token_to_namespace(label, namespace="labels")
+
+     # valid
+     for i, row in valid_dataset.iterrows():
+         label = row["labels"]
+         vocabulary.add_token_to_namespace(label, namespace="labels")
+
+     vocabulary.save_to_files(args.vocabulary_dir)
+
+     return
+
+
+ if __name__ == "__main__":
+     main()
examples/vm_sound_classification/step_3_train_model.py ADDED
@@ -0,0 +1,331 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ from collections import defaultdict
+ import json
+ import logging
+ from logging.handlers import TimedRotatingFileHandler
+ import os
+ import platform
+ from pathlib import Path
+ import random
+ import sys
+ import shutil
+ from typing import List
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import numpy as np
+ import torch
+ from torch.utils.data.dataloader import DataLoader
+ from tqdm import tqdm
+
+ from toolbox.torch.modules.loss import FocalLoss, HingeLoss, HingeLinear
+ from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
+ from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
+ from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel
+ from toolbox.torchaudio.models.cnn_audio_classifier.configuration_cnn_audio_classifier import CnnAudioClassifierConfig
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
+
+     parser.add_argument("--train_dataset", default="train.xlsx", type=str)
+     parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
+
+     parser.add_argument("--max_epochs", default=100, type=int)
+
+     parser.add_argument("--batch_size", default=64, type=int)
+     parser.add_argument("--learning_rate", default=1e-3, type=float)
+     parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
+     parser.add_argument("--patience", default=5, type=int)
+     parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
+     parser.add_argument("--seed", default=0, type=int)
+
+     parser.add_argument("--config_file", default="conv2d_classifier.yaml", type=str)
+
+     args = parser.parse_args()
+     return args
+
+
+ def logging_config(file_dir: str):
+     fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
+
+     logging.basicConfig(format=fmt,
+                         datefmt="%m/%d/%Y %H:%M:%S",
+                         level=logging.DEBUG)
+     file_handler = TimedRotatingFileHandler(
+         filename=os.path.join(file_dir, "main.log"),
+         encoding="utf-8",
+         when="D",
+         interval=1,
+         backupCount=7
+     )
+     file_handler.setLevel(logging.INFO)
+     file_handler.setFormatter(logging.Formatter(fmt))
+     logger = logging.getLogger(__name__)
+     logger.addHandler(file_handler)
+
+     return logger
+
+
+ class CollateFunction(object):
+     def __init__(self):
+         pass
+
+     def __call__(self, batch: List[dict]):
+         array_list = list()
+         label_list = list()
+         for sample in batch:
+             array = sample["waveform"]
+             label = sample["label"]
+
+             array_list.append(array)
+             label_list.append(label)
+
+         array_list = torch.stack(array_list)
+         label_list = torch.stack(label_list)
+         return array_list, label_list
+
+
+ collate_fn = CollateFunction()
+
+
+ def main():
+     args = get_args()
+
+     serialization_dir = Path(args.serialization_dir)
+     serialization_dir.mkdir(parents=True, exist_ok=True)
+
+     logger = logging_config(serialization_dir)
+
+     random.seed(args.seed)
+     np.random.seed(args.seed)
+     torch.manual_seed(args.seed)
+     logger.info("set seed: {}".format(args.seed))
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     n_gpu = torch.cuda.device_count()
+     logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
+
+     vocabulary = Vocabulary.from_files(args.vocabulary_dir)
+
+     # datasets
+     logger.info("prepare datasets")
+     train_dataset = WaveClassifierExcelDataset(
+         vocab=vocabulary,
+         excel_file=args.train_dataset,
+         category=None,
+         category_field="category",
+         label_field="labels",
+         expected_sample_rate=8000,
+         max_wave_value=32768.0,
+     )
+     valid_dataset = WaveClassifierExcelDataset(
+         vocab=vocabulary,
+         excel_file=args.valid_dataset,
+         category=None,
+         category_field="category",
+         label_field="labels",
+         expected_sample_rate=8000,
+         max_wave_value=32768.0,
+     )
+     train_data_loader = DataLoader(
+         dataset=train_dataset,
+         batch_size=args.batch_size,
+         shuffle=True,
+         # On Linux, data can be loaded with multiple worker processes; on Windows it cannot.
+         num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
+         collate_fn=collate_fn,
+         pin_memory=False,
+         # prefetch_factor=64,
+     )
+     valid_data_loader = DataLoader(
+         dataset=valid_dataset,
+         batch_size=args.batch_size,
+         shuffle=True,
+         # On Linux, data can be loaded with multiple worker processes; on Windows it cannot.
+         num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
+         collate_fn=collate_fn,
+         pin_memory=False,
+         # prefetch_factor=64,
+     )
+
+     # models
+     logger.info("prepare models")
+     config = CnnAudioClassifierConfig.from_pretrained(
+         pretrained_model_name_or_path=args.config_file,
+         # num_labels=vocabulary.get_vocab_size(namespace="labels")
+     )
+     if not config.cls_head_param["num_labels"] == vocabulary.get_vocab_size(namespace="labels"):
+         raise AssertionError
+     model = WaveClassifierPretrainedModel(
+         config=config,
+     )
+     model.to(device)
+     model.train()
+
+     # optimizer
+     logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
+     param_optimizer = model.parameters()
+     optimizer = torch.optim.Adam(
+         param_optimizer,
+         lr=args.learning_rate,
+     )
+     # lr_scheduler = torch.optim.lr_scheduler.StepLR(
+     #     optimizer,
+     #     step_size=2000
+     # )
+     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+         optimizer,
+         milestones=[10000, 20000, 30000], gamma=0.5
+     )
+     focal_loss = FocalLoss(
+         num_classes=vocabulary.get_vocab_size(namespace="labels"),
+         reduction="mean",
+     )
+     categorical_accuracy = CategoricalAccuracy()
+
+     # training loop
+     logger.info("training")
+
+     training_loss = 10000000000
+     training_accuracy = 0.
+     evaluation_loss = 10000000000
+     evaluation_accuracy = 0.
+
+     model_list = list()
+     best_idx_epoch = None
+     best_accuracy = None
+     patience_count = 0
+
+     for idx_epoch in range(args.max_epochs):
+         categorical_accuracy.reset()
+         total_loss = 0.
+         total_examples = 0.
+         progress_bar = tqdm(
+             total=len(train_data_loader),
+             desc="Training; epoch: {}".format(idx_epoch),
+         )
+         for batch in train_data_loader:
+             input_ids, label_ids = batch
+             input_ids = input_ids.to(device)
+             label_ids: torch.LongTensor = label_ids.to(device).long()
+
+             logits = model.forward(input_ids)
+             loss = focal_loss.forward(logits, label_ids.view(-1))
+             categorical_accuracy(logits, label_ids)
+
+             total_loss += loss.item()
+             total_examples += input_ids.size(0)
+
+             optimizer.zero_grad()
+             loss.backward()
+             optimizer.step()
+             lr_scheduler.step()
+
+             training_loss = total_loss / total_examples
+             training_loss = round(training_loss, 4)
+             training_accuracy = categorical_accuracy.get_metric()["accuracy"]
+             training_accuracy = round(training_accuracy, 4)
+
+             progress_bar.update(1)
+             progress_bar.set_postfix({
+                 "training_loss": training_loss,
+                 "training_accuracy": training_accuracy,
+             })
+
+         categorical_accuracy.reset()
+         total_loss = 0.
+         total_examples = 0.
+         progress_bar = tqdm(
+             total=len(valid_data_loader),
+             desc="Evaluation; epoch: {}".format(idx_epoch),
+         )
+         for batch in valid_data_loader:
+             input_ids, label_ids = batch
+             input_ids = input_ids.to(device)
+             label_ids: torch.LongTensor = label_ids.to(device).long()
+
+             with torch.no_grad():
+                 logits = model.forward(input_ids)
+                 loss = focal_loss.forward(logits, label_ids.view(-1))
+                 categorical_accuracy(logits, label_ids)
+
+             total_loss += loss.item()
+             total_examples += input_ids.size(0)
+
+             evaluation_loss = total_loss / total_examples
+             evaluation_loss = round(evaluation_loss, 4)
+             evaluation_accuracy = categorical_accuracy.get_metric()["accuracy"]
+             evaluation_accuracy = round(evaluation_accuracy, 4)
+
+             progress_bar.update(1)
+             progress_bar.set_postfix({
+                 "evaluation_loss": evaluation_loss,
+                 "evaluation_accuracy": evaluation_accuracy,
+             })
+
+         # save path
+         epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
+         epoch_dir.mkdir(parents=True, exist_ok=False)
+
+         # save models
+         model.save_pretrained(epoch_dir.as_posix())
+
+         model_list.append(epoch_dir)
+         if len(model_list) >= args.num_serialized_models_to_keep:
+             model_to_delete: Path = model_list.pop(0)
+             shutil.rmtree(model_to_delete.as_posix())
+
+         # save metric
+         if best_accuracy is None:
+             best_idx_epoch = idx_epoch
+             best_accuracy = evaluation_accuracy
+         elif evaluation_accuracy > best_accuracy:
+             best_idx_epoch = idx_epoch
+             best_accuracy = evaluation_accuracy
+         else:
+             pass
+
+         metrics = {
+             "idx_epoch": idx_epoch,
+             "best_idx_epoch": best_idx_epoch,
+             "best_accuracy": best_accuracy,
+             "training_loss": training_loss,
+             "training_accuracy": training_accuracy,
+             "evaluation_loss": evaluation_loss,
+             "evaluation_accuracy": evaluation_accuracy,
+             "learning_rate": optimizer.param_groups[0]['lr'],
+         }
+         metrics_filename = epoch_dir / "metrics_epoch.json"
+         with open(metrics_filename, "w", encoding="utf-8") as f:
+             json.dump(metrics, f, indent=4, ensure_ascii=False)
+
+         # save best
+         best_dir = serialization_dir / "best"
+         if best_idx_epoch == idx_epoch:
+             if best_dir.exists():
+                 shutil.rmtree(best_dir)
+             shutil.copytree(epoch_dir, best_dir)
+
+         # early stop
+         early_stop_flag = False
+         if best_idx_epoch == idx_epoch:
+             patience_count = 0
+         else:
+             patience_count += 1
+         if patience_count >= args.patience:
+             early_stop_flag = True
+
+         # early stop
+         if early_stop_flag:
+             break
+     return
+
+
+ if __name__ == "__main__":
+     main()
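Note: the training script above relies on FocalLoss from toolbox.torch.modules.loss (added later in this commit); its internals are not shown here. As a rough reference only, a standard multi-class focal loss over logits can be sketched as below; this follows the conventional formulation FL(p_t) = -(1 - p_t)^gamma * log(p_t) and is not necessarily identical to the toolbox implementation:

import torch
import torch.nn.functional as F

def focal_loss(logits: torch.Tensor, targets: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    # Down-weight well-classified examples by the factor (1 - p_t)^gamma.
    log_probs = F.log_softmax(logits, dim=-1)                           # (batch, num_classes)
    log_pt = log_probs.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
    pt = log_pt.exp()
    loss = -((1.0 - pt) ** gamma) * log_pt                              # (batch,)
    return loss.mean()

logits = torch.randn(4, 8)               # batch of 4, 8 classes (as in conv2d_classifier.yaml)
targets = torch.tensor([0, 3, 7, 2])
print(focal_loss(logits, targets))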
examples/vm_sound_classification/step_4_evaluation_model.py ADDED
@@ -0,0 +1,128 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ from collections import defaultdict
+ import json
+ import logging
+ from logging.handlers import TimedRotatingFileHandler
+ import os
+ import platform
+ from pathlib import Path
+ import sys
+ import shutil
+ from typing import List
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import pandas as pd
+ from scipy.io import wavfile
+ import torch
+ from tqdm import tqdm
+
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
+ from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--dataset", default="dataset.xlsx", type=str)
+
+     parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
+     parser.add_argument("--model_dir", default="best", type=str)
+
+     parser.add_argument("--output_file", default="evaluation.xlsx", type=str)
+
+     args = parser.parse_args()
+     return args
+
+
+ def logging_config():
+     fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
+
+     logging.basicConfig(format=fmt,
+                         datefmt="%m/%d/%Y %H:%M:%S",
+                         level=logging.DEBUG)
+     stream_handler = logging.StreamHandler()
+     stream_handler.setLevel(logging.INFO)
+     stream_handler.setFormatter(logging.Formatter(fmt))
+
+     logger = logging.getLogger(__name__)
+
+     return logger
+
+
+ def main():
+     args = get_args()
+
+     logger = logging_config()
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     n_gpu = torch.cuda.device_count()
+     logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
+
+     logger.info("prepare vocabulary, model")
+     vocabulary = Vocabulary.from_files(args.vocabulary_dir)
+
+     model = WaveClassifierPretrainedModel.from_pretrained(
+         pretrained_model_name_or_path=args.model_dir,
+     )
+     model.to(device)
+     model.eval()
+
+     logger.info("read excel")
+     df = pd.read_excel(args.dataset)
+     result = list()
+
+     total_correct = 0
+     total_examples = 0
+
+     progress_bar = tqdm(total=len(df), desc="Evaluation")
+     for i, row in df.iterrows():
+         filename = row["filename"]
+         ground_true = row["labels"]
+
+         sample_rate, waveform = wavfile.read(filename)
+         waveform = waveform / (1 << 15)
+         waveform = torch.tensor(waveform, dtype=torch.float32)
+         waveform = torch.unsqueeze(waveform, dim=0)
+         waveform = waveform.to(device)
+
+         with torch.no_grad():
+             logits = model.forward(waveform)
+             probs = torch.nn.functional.softmax(logits, dim=-1)
+             label_idx = torch.argmax(probs, dim=-1)
+
+         label_idx = label_idx.cpu()
+         probs = probs.cpu()
+
+         label_idx = label_idx.numpy()[0]
+         label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
+         prob = probs[0][label_idx].numpy()
+
+         correct = 1 if label_str == ground_true else 0
+         row_ = dict(row)
+         row_["predict"] = label_str
+         row_["prob"] = prob
+         row_["correct"] = correct
+         result.append(row_)
+
+         total_examples += 1
+         total_correct += correct
+         accuracy = total_correct / total_examples
+
+         progress_bar.update(1)
+         progress_bar.set_postfix({
+             "accuracy": accuracy,
+         })
+
+     result = pd.DataFrame(result)
+     result.to_excel(
+         args.output_file,
+         index=False
+     )
+     return
+
+
+ if __name__ == '__main__':
+     main()
examples/vm_sound_classification/step_5_export_models.py ADDED
@@ -0,0 +1,106 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ from collections import defaultdict
+ import json
+ import logging
+ from logging.handlers import TimedRotatingFileHandler
+ import os
+ import platform
+ from pathlib import Path
+ import sys
+ import shutil
+ from typing import List
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import numpy as np
+ import torch
+
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
+ from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
+     parser.add_argument("--model_dir", default="best", type=str)
+
+     parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
+
+     args = parser.parse_args()
+     return args
+
+
+ def logging_config():
+     fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
+
+     logging.basicConfig(format=fmt,
+                         datefmt="%m/%d/%Y %H:%M:%S",
+                         level=logging.DEBUG)
+     stream_handler = logging.StreamHandler()
+     stream_handler.setLevel(logging.INFO)
+     stream_handler.setFormatter(logging.Formatter(fmt))
+
+     logger = logging.getLogger(__name__)
+
+     return logger
+
+
+ def main():
+     args = get_args()
+
+     serialization_dir = Path(args.serialization_dir)
+
+     logger = logging_config()
+
+     logger.info("export models on CPU")
+     device = torch.device("cpu")
+
+     logger.info("prepare vocabulary, model")
+     vocabulary = Vocabulary.from_files(args.vocabulary_dir)
+
+     model = WaveClassifierPretrainedModel.from_pretrained(
+         pretrained_model_name_or_path=args.model_dir,
+         num_labels=vocabulary.get_vocab_size(namespace="labels")
+     )
+     model.to(device)
+     model.eval()
+
+     waveform = 0 + 25 * np.random.randn(16000,)
+     waveform = np.array(waveform, dtype=np.int16)
+     waveform = waveform / (1 << 15)
+     waveform = torch.tensor(waveform, dtype=torch.float32)
+     waveform = torch.unsqueeze(waveform, dim=0)
+     waveform = waveform.to(device)
+
+     logger.info("export jit models")
+     example_inputs = (waveform,)
+
+     # trace model
+     trace_model = torch.jit.trace(func=model, example_inputs=example_inputs, strict=False)
+     trace_model.save(serialization_dir / "trace_model.zip")
+
+     # quantized trace model (does not work on GPU)
+     quantized_model = torch.quantization.quantize_dynamic(
+         model, {torch.nn.Linear}, dtype=torch.qint8
+     )
+     trace_quant_model = torch.jit.trace(func=quantized_model, example_inputs=example_inputs, strict=False)
+     trace_quant_model.save(serialization_dir / "trace_quant_model.zip")
+
+     # script model
+     script_model = torch.jit.script(obj=model)
+     script_model.save(serialization_dir / "script_model.zip")
+
+     # quantized script model (does not work on GPU)
+     quantized_model = torch.quantization.quantize_dynamic(
+         model, {torch.nn.Linear}, dtype=torch.qint8
+     )
+     script_quant_model = torch.jit.script(quantized_model)
+     script_quant_model.save(serialization_dir / "script_quant_model.zip")
+     return
+
+
+ if __name__ == '__main__':
+     main()
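Note: step_5_export_models.py combines dynamic quantization with TorchScript tracing. A self-contained sketch of the same export pattern on a toy linear model (a stand-in for WaveClassifierPretrainedModel, which is defined elsewhere in this commit), under the assumption that only torch.nn.Linear layers are quantized:

import torch

class TinyClassifier(torch.nn.Module):
    # Stand-in model: a single Linear layer, eligible for dynamic quantization.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

model = TinyClassifier().eval()
example_inputs = (torch.randn(1, 16),)

# TorchScript trace of the float model.
trace_model = torch.jit.trace(model, example_inputs)
trace_model.save("trace_model.zip")

# Dynamic quantization of Linear layers to int8, then trace the quantized model.
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
trace_quant_model = torch.jit.trace(quantized_model, example_inputs)
trace_quant_model.save("trace_quant_model.zip")

# Reload and run the quantized artifact.
reloaded = torch.jit.load("trace_quant_model.zip")
print(reloaded(example_inputs[0]).shape)  # torch.Size([1, 8])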
examples/vm_sound_classification/step_6_infer.py ADDED
@@ -0,0 +1,91 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ import os
+ from pathlib import Path
+ import shutil
+ import sys
+ import tempfile
+ import zipfile
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ from scipy.io import wavfile
+ import torch
+
+ from project_settings import project_path
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--model_file",
+         default=(project_path / "trained_models/vm_sound_classification3.zip").as_posix(),
+         type=str
+     )
+     parser.add_argument(
+         "--wav_file",
+         default=r"C:\Users\tianx\Desktop\a073d03d-d280-46df-9b2d-d904965f4500_zh-CN_h3f25ivhb0c0_1719478037746.wav",
+         type=str
+     )
+
+     parser.add_argument("--device", default="cpu", type=str)
+
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = get_args()
+
+     model_file = Path(args.model_file)
+
+     device = torch.device(args.device)
+
+     with zipfile.ZipFile(model_file, "r") as f_zip:
+         out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
+         print(out_root.as_posix())
+         if out_root.exists():
+             shutil.rmtree(out_root.as_posix())
+         out_root.mkdir(parents=True, exist_ok=True)
+         f_zip.extractall(path=out_root)
+
+     tgt_path = out_root / model_file.stem
+     jit_model_file = tgt_path / "trace_model.zip"
+     vocab_path = tgt_path / "vocabulary"
+
+     with open(jit_model_file.as_posix(), "rb") as f:
+         model = torch.jit.load(f)
+     model.to(device)
+     model.eval()
+     vocabulary = Vocabulary.from_files(vocab_path.as_posix())
+
+     # infer
+     sample_rate, waveform = wavfile.read(args.wav_file)
+     waveform = waveform[:16000]
+     waveform = waveform / (1 << 15)
+     waveform = torch.tensor(waveform, dtype=torch.float32)
+     waveform = torch.unsqueeze(waveform, dim=0)
+     waveform = waveform.to(device)
+
+     with torch.no_grad():
+         logits = model.forward(waveform)
+         probs = torch.nn.functional.softmax(logits, dim=-1)
+         label_idx = torch.argmax(probs, dim=-1)
+
+     label_idx = label_idx.cpu()
+     probs = probs.cpu()
+
+     label_idx = label_idx.numpy()[0]
+     prob = probs.numpy()[0][label_idx]
+
+     label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
+     print(label_str)
+     print(prob)
+     return
+
+
+ if __name__ == '__main__':
+     main()
examples/vm_sound_classification/step_7_test_model.py ADDED
@@ -0,0 +1,93 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ import os
+ from pathlib import Path
+ import shutil
+ import sys
+ import tempfile
+ import zipfile
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ from scipy.io import wavfile
+ import torch
+
+ from project_settings import project_path
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
+ from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--model_file",
+         default=(project_path / "trained_models/vm_sound_classification3.zip").as_posix(),
+         type=str
+     )
+     parser.add_argument(
+         "--wav_file",
+         default=r"C:\Users\tianx\Desktop\4b284733-0be3-4a48-abbb-615b32ac44b7_6ndddc2szlh0.wav",
+         type=str
+     )
+
+     parser.add_argument("--device", default="cpu", type=str)
+
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = get_args()
+
+     model_file = Path(args.model_file)
+
+     device = torch.device(args.device)
+
+     with zipfile.ZipFile(model_file, "r") as f_zip:
+         out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
+         print(out_root)
+         if out_root.exists():
+             shutil.rmtree(out_root.as_posix())
+         out_root.mkdir(parents=True, exist_ok=True)
+         f_zip.extractall(path=out_root)
+
+     tgt_path = out_root / model_file.stem
+     vocab_path = tgt_path / "vocabulary"
+
+     vocabulary = Vocabulary.from_files(vocab_path.as_posix())
+
+     model = WaveClassifierPretrainedModel.from_pretrained(
+         pretrained_model_name_or_path=tgt_path.as_posix(),
+     )
+     model.to(device)
+     model.eval()
+
+     # infer
+     sample_rate, waveform = wavfile.read(args.wav_file)
+     waveform = waveform[:16000]
+     waveform = waveform / (1 << 15)
+     waveform = torch.tensor(waveform, dtype=torch.float32)
+     waveform = torch.unsqueeze(waveform, dim=0)
+     waveform = waveform.to(device)
+     print(waveform.shape)
+     with torch.no_grad():
+         logits = model.forward(waveform)
+         probs = torch.nn.functional.softmax(logits, dim=-1)
+         label_idx = torch.argmax(probs, dim=-1)
+
+     label_idx = label_idx.cpu()
+     probs = probs.cpu()
+
+     label_idx = label_idx.numpy()[0]
+     prob = probs.numpy()[0][label_idx]
+
+     label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
+     print(label_str)
+     print(prob)
+     return
+
+
+ if __name__ == '__main__':
+     main()
examples/vm_sound_classification/stop.sh ADDED
@@ -0,0 +1,3 @@
+ #!/usr/bin/env bash
+
+ kill -9 `ps -aef | grep 'vm_sound_classification/bin/python3' | grep -v grep | awk '{print $2}' | sed 's/\n/ /'`
examples/vm_sound_classification8/requirements.txt ADDED
@@ -0,0 +1,9 @@
+ torch==1.10.1
+ torchaudio==0.10.1
+ fsspec==2022.1.0
+ librosa==0.9.2
+ pandas==1.1.5
+ openpyxl==3.0.9
+ xlrd==1.2.0
+ tqdm==4.64.1
+ overrides==1.9.0
examples/vm_sound_classification8/run.sh ADDED
@@ -0,0 +1,157 @@
+ #!/usr/bin/env bash
+
+ : <<'END'
+
+ sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification8 \
+ --filename_patterns "E:/programmer/asr_datasets/voicemail/wav_finished/en-US/wav_finished/*/*.wav \
+ E:/programmer/asr_datasets/voicemail/wav_finished/id-ID/wav_finished/*/*.wav" \
+
+
+ sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification8 \
+ --filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
+
+ sh run.sh --stage 4 --stop_stage 4 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification8 \
+ --filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
+
+ sh run.sh --stage 4 --stop_stage 4 --system_version centos --file_folder_name file_dir --final_model_name vm_sound_classification8 \
+ --filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
+
+
+ "
+
+ END
+
+
+ # sh run.sh --stage -1 --stop_stage 9
+ # sh run.sh --stage -1 --stop_stage 5 --system_version centos --file_folder_name task_cnn_voicemail_id_id --final_model_name cnn_voicemail_id_id
+ # sh run.sh --stage 3 --stop_stage 4
+ # sh run.sh --stage 4 --stop_stage 4
+ # sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name task_cnn_voicemail_id_id
+
+ # params
+ system_version="windows";
+ verbose=true;
+ stage=0 # start from 0 if you need to start from data preparation
+ stop_stage=9
+
+ work_dir="$(pwd)"
+ file_folder_name=file_folder_name
+ final_model_name=final_model_name
+ filename_patterns="/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
+ nohup_name=nohup.out
+
+ country=en-US
+
+ # model params
+ batch_size=64
+ max_epochs=200
+ save_top_k=10
+ patience=5
+
+
+ # parse options
+ while true; do
+   [ -z "${1:-}" ] && break;  # break if there are no arguments
+   case "$1" in
+     --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
+       eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+       old_value="(eval echo \\$$name)";
+       if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
+         was_bool=true;
+       else
+         was_bool=false;
+       fi
+
+       # Set the variable to the right value-- the escaped quotes make it work if
+       # the option had spaces, like --cmd "queue.pl -sync y"
+       eval "${name}=\"$2\"";
+
+       # Check that Boolean-valued arguments are really Boolean.
+       if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+         echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+         exit 1;
+       fi
+       shift 2;
+       ;;
+
+     *) break;
+   esac
+ done
+
+ file_dir="${work_dir}/${file_folder_name}"
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
+
+ train_dataset="${file_dir}/train.xlsx"
+ valid_dataset="${file_dir}/valid.xlsx"
+ vocabulary_dir="${file_dir}/vocabulary"
+
+
+ $verbose && echo "system_version: ${system_version}"
+ $verbose && echo "file_folder_name: ${file_folder_name}"
+
+ if [ $system_version == "windows" ]; then
+   alias python3='D:/Users/tianx/PycharmProjects/virtualenv/vm_sound_classification/Scripts/python.exe'
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
+   #source /data/local/bin/vm_sound_classification/bin/activate
+   alias python3='/data/local/bin/vm_sound_classification/bin/python3'
+ fi
+
+
+ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+   $verbose && echo "stage 0: prepare data"
+   cd "${work_dir}" || exit 1
+   python3 step_1_prepare_data.py \
+     --file_dir "${file_dir}" \
+     --filename_patterns "${filename_patterns}" \
+     --train_dataset "${train_dataset}" \
+     --valid_dataset "${valid_dataset}" \
+
+ fi
+
+
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+   $verbose && echo "stage 1: make vocabulary"
+   cd "${work_dir}" || exit 1
+   python3 step_2_make_vocabulary.py \
+     --vocabulary_dir "${vocabulary_dir}" \
+     --train_dataset "${train_dataset}" \
+     --valid_dataset "${valid_dataset}" \
+
+ fi
+
+
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+   $verbose && echo "stage 2: train global model"
+   cd "${work_dir}" || exit 1
+   python3 step_3_train_global_model.py \
+     --vocabulary_dir "${vocabulary_dir}" \
+     --train_dataset "${train_dataset}" \
+     --valid_dataset "${valid_dataset}" \
+     --serialization_dir "${file_dir}/global_model" \
+
+ fi
+
+
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+   $verbose && echo "stage 3: train country model"
+   cd "${work_dir}" || exit 1
+   python3 step_4_train_country_model.py \
+     --vocabulary_dir "${vocabulary_dir}" \
+     --train_dataset "${train_dataset}" \
+     --valid_dataset "${valid_dataset}" \
+     --country "${country}" \
+     --serialization_dir "${file_dir}/country_model" \
+
+ fi
+
+
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+   $verbose && echo "stage 4: train union model"
+   cd "${work_dir}" || exit 1
+   python3 step_5_train_union.py \
+     --vocabulary_dir "${vocabulary_dir}" \
+     --train_dataset "${train_dataset}" \
+     --valid_dataset "${valid_dataset}" \
+     --serialization_dir "${file_dir}/union" \
+
+ fi
examples/vm_sound_classification8/step_1_prepare_data.py ADDED
@@ -0,0 +1,156 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ from glob import glob
+ import json
+ import os
+ from pathlib import Path
+ import random
+ import sys
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import pandas as pd
+ from scipy.io import wavfile
+ from tqdm import tqdm
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--file_dir", default="./", type=str)
+     parser.add_argument("--task", default="default", type=str)
+     parser.add_argument("--filename_patterns", type=str)
+
+     parser.add_argument("--train_dataset", default="train.xlsx", type=str)
+     parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
+
+     args = parser.parse_args()
+     return args
+
+
+ def get_dataset(args):
+     filename_patterns = args.filename_patterns
+     filename_patterns = filename_patterns.split(" ")
+     print(filename_patterns)
+
+     file_dir = Path(args.file_dir)
+     file_dir.mkdir(exist_ok=True)
+
+     global_label_map = {
+         "bell": "bell",
+         "white_noise": "white_noise",
+         "low_white_noise": "white_noise",
+         "high_white_noise": "noise",
+         "music": "music",
+         "mute": "mute",
+         "noise": "noise",
+         "noise_mute": "noise_mute",
+         "voice": "voice",
+         "voicemail": "voicemail",
+     }
+
+     country_label_map = {
+         "bell": "voicemail",
+         "white_noise": "non_voicemail",
+         "low_white_noise": "non_voicemail",
+         "hight_white_noise": "non_voicemail",
+         "music": "non_voicemail",
+         "mute": "non_voicemail",
+         "noise": "non_voicemail",
+         "noise_mute": "non_voicemail",
+         "voice": "non_voicemail",
+         "voicemail": "voicemail",
+         "non_voicemail": "non_voicemail",
+     }
+
+     result = list()
+     for filename_pattern in filename_patterns:
+         filename_list = glob(filename_pattern)
+         for filename in tqdm(filename_list):
+             filename = Path(filename)
+             sample_rate, signal = wavfile.read(filename.as_posix())
+             if len(signal) < sample_rate * 2:
+                 continue
+
+             folder = filename.parts[-2]
+             country = filename.parts[-4]
+
+             if folder not in global_label_map.keys():
+                 continue
+             if folder not in country_label_map.keys():
+                 continue
+
+             global_label = global_label_map[folder]
+             country_label = country_label_map[folder]
+
+             random1 = random.random()
+             random2 = random.random()
+
+             result.append({
+                 "filename": filename,
+                 "folder": folder,
+                 "category": country,
+                 "global_labels": global_label,
+                 "country_labels": country_label,
+                 "random1": random1,
+                 "random2": random2,
+                 "flag": "TRAIN" if random2 < 0.8 else "TEST",
+             })
+
+     df = pd.DataFrame(result)
+     pivot_table = pd.pivot_table(df, index=["global_labels"], values=["filename"], aggfunc="count")
+     print(pivot_table)
+
+     df = df.sort_values(by=["random1"], ascending=False)
+     df.to_excel(
+         file_dir / "dataset.xlsx",
+         index=False,
+         # encoding="utf_8_sig"
+     )
+
+     return
+
+
+ def split_dataset(args):
+     """Split the data into train and test sets."""
+     file_dir = Path(args.file_dir)
+     file_dir.mkdir(exist_ok=True)
+
+     df = pd.read_excel(file_dir / "dataset.xlsx")
+
+     train = list()
+     test = list()
+
+     for i, row in df.iterrows():
+         flag = row["flag"]
+         if flag == "TRAIN":
+             train.append(row)
+         else:
+             test.append(row)
+
+     train = pd.DataFrame(train)
+     train.to_excel(
+         args.train_dataset,
+         index=False,
+         # encoding="utf_8_sig"
+     )
+     test = pd.DataFrame(test)
+     test.to_excel(
+         args.valid_dataset,
+         index=False,
+         # encoding="utf_8_sig"
+     )
+
+     return
+
+
+ def main():
+     args = get_args()
+     get_dataset(args)
+     split_dataset(args)
+     return
+
+
+ if __name__ == "__main__":
+     main()
examples/vm_sound_classification8/step_2_make_vocabulary.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ from pathlib import Path
6
+ import sys
7
+
8
+ pwd = os.path.abspath(os.path.dirname(__file__))
9
+ sys.path.append(os.path.join(pwd, "../../"))
10
+
11
+ import pandas as pd
12
+
13
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
14
+
15
+
16
+ def get_args():
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
19
+
20
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
21
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
22
+
23
+ args = parser.parse_args()
24
+ return args
25
+
26
+
27
+ def main():
28
+ args = get_args()
29
+
30
+ train_dataset = pd.read_excel(args.train_dataset)
31
+ valid_dataset = pd.read_excel(args.valid_dataset)
32
+
33
+ # non_padded_namespaces
34
+ category_set = set()
35
+ for i, row in train_dataset.iterrows():
36
+ category = row["category"]
37
+ category_set.add(category)
38
+
39
+ for i, row in valid_dataset.iterrows():
40
+ category = row["category"]
41
+ category_set.add(category)
42
+
43
+ vocabulary = Vocabulary(non_padded_namespaces=["global_labels", *list(category_set)])
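+ # One shared "global_labels" namespace plus one namespace per country; non-padded namespaces are assumed (AllenNLP-style) to skip the padding/OOV tokens.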
44
+
45
+ # train
46
+ for i, row in train_dataset.iterrows():
47
+ global_labels = row["global_labels"]
48
+ country_labels = row["country_labels"]
49
+ category = row["category"]
50
+
51
+ vocabulary.add_token_to_namespace(global_labels, "global_labels")
52
+ vocabulary.add_token_to_namespace(country_labels, category)
53
+
54
+ # valid
55
+ for i, row in valid_dataset.iterrows():
56
+ global_labels = row["global_labels"]
57
+ country_labels = row["country_labels"]
58
+ category = row["category"]
59
+
60
+ vocabulary.add_token_to_namespace(global_labels, "global_labels")
61
+ vocabulary.add_token_to_namespace(country_labels, category)
62
+
63
+ vocabulary.save_to_files(args.vocabulary_dir)
64
+
65
+ return
66
+
67
+
68
+ if __name__ == "__main__":
69
+ main()
examples/vm_sound_classification8/step_3_train_global_model.py ADDED
@@ -0,0 +1,328 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ The previous version of this code reached an accuracy of 0.8423.
5
+ This version reaches an accuracy of 0.8379.
6
+ This approach is viable.
7
+ """
8
+ import argparse
9
+ import copy
10
+ import json
11
+ import logging
12
+ from logging.handlers import TimedRotatingFileHandler
13
+ import os
14
+ from pathlib import Path
15
+ import platform
16
+ import sys
17
+ from typing import List
18
+
19
+ pwd = os.path.abspath(os.path.dirname(__file__))
20
+ sys.path.append(os.path.join(pwd, "../../"))
21
+
22
+ import torch
23
+ from torch.utils.data.dataloader import DataLoader
24
+ from tqdm import tqdm
25
+
26
+ from toolbox.torch.modules.loss import FocalLoss, HingeLoss, HingeLinear
27
+ from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
28
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
29
+ from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
30
+ from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveEncoder, ClsHead, WaveClassifier
31
+
32
+
33
+ def get_args():
34
+ parser = argparse.ArgumentParser()
35
+ parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
36
+
37
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
38
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
39
+
40
+ parser.add_argument("--max_epochs", default=100, type=int)
41
+ parser.add_argument("--batch_size", default=64, type=int)
42
+ parser.add_argument("--learning_rate", default=1e-3, type=float)
43
+ parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
44
+ parser.add_argument("--patience", default=5, type=int)
45
+ parser.add_argument("--serialization_dir", default="global_classifier", type=str)
46
+ parser.add_argument("--seed", default=0, type=int)
47
+
48
+ args = parser.parse_args()
49
+ return args
50
+
51
+
52
+ def logging_config(file_dir: str):
53
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
54
+
55
+ logging.basicConfig(format=fmt,
56
+ datefmt="%m/%d/%Y %H:%M:%S",
57
+ level=logging.DEBUG)
58
+ file_handler = TimedRotatingFileHandler(
59
+ filename=os.path.join(file_dir, "main.log"),
60
+ encoding="utf-8",
61
+ when="D",
62
+ interval=1,
63
+ backupCount=7
64
+ )
65
+ file_handler.setLevel(logging.INFO)
66
+ file_handler.setFormatter(logging.Formatter(fmt))
67
+ logger = logging.getLogger(__name__)
68
+ logger.addHandler(file_handler)
69
+
70
+ return logger
71
+
72
+
73
+ class CollateFunction(object):
74
+ def __init__(self):
75
+ pass
76
+
77
+ def __call__(self, batch: List[dict]):
78
+ array_list = list()
79
+ label_list = list()
80
+ for sample in batch:
81
+ array = sample["waveform"]
82
+ label = sample["label"]
83
+
84
+ array_list.append(array)
85
+ label_list.append(label)
86
+
87
+ array_list = torch.stack(array_list)
88
+ label_list = torch.stack(label_list)
89
+ return array_list, label_list
90
+
91
+
92
+ collate_fn = CollateFunction()
93
+
94
+
95
+ def main():
96
+ args = get_args()
97
+
98
+ serialization_dir = Path(args.serialization_dir)
99
+ serialization_dir.mkdir(parents=True, exist_ok=True)
100
+
101
+ logger = logging_config(args.serialization_dir)
102
+
103
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
104
+ n_gpu = torch.cuda.device_count()
105
+ logger.info("GPU available: {}; device: {}".format(n_gpu, device))
106
+
107
+ vocabulary = Vocabulary.from_files(args.vocabulary_dir)
108
+
109
+ # datasets
110
+ train_dataset = WaveClassifierExcelDataset(
111
+ vocab=vocabulary,
112
+ excel_file=args.train_dataset,
113
+ category=None,
114
+ category_field="category",
115
+ label_field="global_labels",
116
+ expected_sample_rate=8000,
117
+ max_wave_value=32768.0,
118
+ )
119
+ valid_dataset = WaveClassifierExcelDataset(
120
+ vocab=vocabulary,
121
+ excel_file=args.valid_dataset,
122
+ category=None,
123
+ category_field="category",
124
+ label_field="global_labels",
125
+ expected_sample_rate=8000,
126
+ max_wave_value=32768.0,
127
+ )
128
+
129
+ train_data_loader = DataLoader(
130
+ dataset=train_dataset,
131
+ batch_size=args.batch_size,
132
+ shuffle=True,
133
+ # On Linux, data loading can use multiple worker subprocesses; on Windows it cannot.
134
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count(),
135
+ collate_fn=collate_fn,
136
+ pin_memory=False,
137
+ # prefetch_factor=64,
138
+ )
139
+ valid_data_loader = DataLoader(
140
+ dataset=valid_dataset,
141
+ batch_size=args.batch_size,
142
+ shuffle=True,
143
+ # On Linux, data loading can use multiple worker subprocesses; on Windows it cannot.
144
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count(),
145
+ collate_fn=collate_fn,
146
+ pin_memory=False,
147
+ # prefetch_factor=64,
148
+ )
149
+
150
+ # models - classifier
151
+ wave_encoder = WaveEncoder(
152
+ conv1d_block_param_list=[
153
+ {
154
+ 'batch_norm': True,
155
+ 'in_channels': 80,
156
+ 'out_channels': 16,
157
+ 'kernel_size': 3,
158
+ 'stride': 3,
159
+ # 'padding': 'same',
160
+ 'activation': 'relu',
161
+ 'dropout': 0.1,
162
+ },
163
+ {
164
+ # 'batch_norm': True,
165
+ 'in_channels': 16,
166
+ 'out_channels': 16,
167
+ 'kernel_size': 3,
168
+ 'stride': 3,
169
+ # 'padding': 'same',
170
+ 'activation': 'relu',
171
+ 'dropout': 0.1,
172
+ },
173
+ {
174
+ # 'batch_norm': True,
175
+ 'in_channels': 16,
176
+ 'out_channels': 16,
177
+ 'kernel_size': 3,
178
+ 'stride': 3,
179
+ # 'padding': 'same',
180
+ 'activation': 'relu',
181
+ 'dropout': 0.1,
182
+ },
183
+ ],
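+ # Mel-spectrogram front end: 8 kHz audio, 25 ms window (200 samples), 10 ms hop (80 samples), 80 mel bins over 10-3800 Hz.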
184
+ mel_spectrogram_param={
185
+ "sample_rate": 8000,
186
+ "n_fft": 512,
187
+ "win_length": 200,
188
+ "hop_length": 80,
189
+ "f_min": 10,
190
+ "f_max": 3800,
191
+ "window_fn": "hamming",
192
+ "n_mels": 80,
193
+ }
194
+ )
195
+ cls_head = ClsHead(
196
+ input_dim=16,
197
+ num_layers=2,
198
+ hidden_dims=[32, 16],
199
+ activations="relu",
200
+ dropout=0.1,
201
+ num_labels=vocabulary.get_vocab_size(namespace="global_labels")
202
+ )
203
+ model = WaveClassifier(
204
+ wave_encoder=wave_encoder,
205
+ cls_head=cls_head,
206
+ )
207
+ model.to(device)
208
+
209
+ # optimizer
210
+ optimizer = torch.optim.Adam(
211
+ model.parameters(),
212
+ lr=args.learning_rate
213
+ )
214
+ lr_scheduler = torch.optim.lr_scheduler.StepLR(
215
+ optimizer,
216
+ step_size=30000
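+ # lr_scheduler.step() is called once per batch below, so step_size=30000 counts optimizer steps, not epochs.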
217
+ )
218
+ focal_loss = FocalLoss(
219
+ num_classes=vocabulary.get_vocab_size(namespace="global_labels"),
220
+ reduction="mean",
221
+ )
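+ # Focal loss is used instead of plain cross-entropy to down-weight easy examples under class imbalance.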
222
+ categorical_accuracy = CategoricalAccuracy()
223
+
224
+ # training
225
+ best_idx_epoch: int = None
226
+ best_accuracy: float = None
227
+ patience_count = 0
228
+ global_step = 0
229
+ model_filename_list = list()
230
+ for idx_epoch in range(args.max_epochs):
231
+
232
+ # training
233
+ model.train()
234
+ total_loss = 0
235
+ total_examples = 0
236
+ for step, batch in enumerate(tqdm(train_data_loader, desc="Epoch={} (training)".format(idx_epoch))):
237
+ input_ids, label_ids = batch
238
+ input_ids = input_ids.to(device)
239
+ label_ids: torch.LongTensor = label_ids.to(device).long()
240
+
241
+ logits = model.forward(input_ids)
242
+ loss = focal_loss.forward(logits, label_ids.view(-1))
243
+ categorical_accuracy(logits, label_ids)
244
+
245
+ total_loss += loss.item()
246
+ total_examples += input_ids.size(0)
247
+
248
+ optimizer.zero_grad()
249
+ loss.backward()
250
+ optimizer.step()
251
+ lr_scheduler.step()
252
+
253
+ global_step += 1
254
+ training_loss = total_loss / total_examples
255
+ training_loss = round(training_loss, 4)
256
+ training_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
257
+ training_accuracy = round(training_accuracy, 4)
258
+ logger.info("Epoch: {}; training_loss: {}; training_accuracy: {}".format(
259
+ idx_epoch, training_loss, training_accuracy
260
+ ))
261
+
262
+ # evaluation
263
+ model.eval()
264
+ total_loss = 0
265
+ total_examples = 0
266
+ for step, batch in enumerate(tqdm(valid_data_loader, desc="Epoch={} (evaluation)".format(idx_epoch))):
267
+ input_ids, label_ids = batch
268
+ input_ids = input_ids.to(device)
269
+ label_ids: torch.LongTensor = label_ids.to(device).long()
270
+
271
+ with torch.no_grad():
272
+ logits = model.forward(input_ids)
273
+ loss = focal_loss.forward(logits, label_ids.view(-1))
274
+ categorical_accuracy(logits, label_ids)
275
+
276
+ total_loss += loss.item()
277
+ total_examples += input_ids.size(0)
278
+
279
+ evaluation_loss = total_loss / total_examples
280
+ evaluation_loss = round(evaluation_loss, 4)
281
+ evaluation_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
282
+ evaluation_accuracy = round(evaluation_accuracy, 4)
283
+ logger.info("Epoch: {}; evaluation_loss: {}; evaluation_accuracy: {}".format(
284
+ idx_epoch, evaluation_loss, evaluation_accuracy
285
+ ))
286
+
287
+ # save metric
288
+ metrics = {
289
+ "training_loss": training_loss,
290
+ "training_accuracy": training_accuracy,
291
+ "evaluation_loss": evaluation_loss,
292
+ "evaluation_accuracy": evaluation_accuracy,
293
+ "best_idx_epoch": best_idx_epoch,
294
+ "best_accuracy": best_accuracy,
295
+ }
296
+ metrics_filename = os.path.join(args.serialization_dir, "metrics_epoch_{}.json".format(idx_epoch))
297
+ with open(metrics_filename, "w", encoding="utf-8") as f:
298
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
299
+
300
+ # save model
301
+ model_filename = os.path.join(args.serialization_dir, "model_epoch_{}.bin".format(idx_epoch))
302
+ model_filename_list.append(model_filename)
303
+ if len(model_filename_list) >= args.num_serialized_models_to_keep:
304
+ model_filename_to_delete = model_filename_list.pop(0)
305
+ os.remove(model_filename_to_delete)
306
+ torch.save(model.state_dict(), model_filename)
307
+
308
+ # early stop
309
+ best_model_filename = os.path.join(args.serialization_dir, "best.bin")
310
+ if best_accuracy is None:
311
+ best_idx_epoch = idx_epoch
312
+ best_accuracy = evaluation_accuracy
313
+ torch.save(model.state_dict(), best_model_filename)
314
+ elif evaluation_accuracy > best_accuracy:
315
+ best_idx_epoch = idx_epoch
316
+ best_accuracy = evaluation_accuracy
317
+ torch.save(model.state_dict(), best_model_filename)
318
+ patience_count = 0
319
+ elif patience_count >= args.patience:
320
+ break
321
+ else:
322
+ patience_count += 1
323
+
324
+ return
325
+
326
+
327
+ if __name__ == "__main__":
328
+ main()
examples/vm_sound_classification8/step_4_train_country_model.py ADDED
@@ -0,0 +1,349 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Train only the cls_head parameters; the model's accuracy will be lower.
5
+ """
6
+ import argparse
7
+ from collections import defaultdict
8
+ import json
9
+ import logging
10
+ from logging.handlers import TimedRotatingFileHandler
11
+ import os
12
+ import platform
13
+ from pathlib import Path
14
+ import sys
15
+ import shutil
16
+ from typing import List
17
+
18
+ pwd = os.path.abspath(os.path.dirname(__file__))
19
+ sys.path.append(os.path.join(pwd, "../../"))
20
+
21
+ import pandas as pd
22
+ import torch
23
+ from torch.utils.data.dataloader import DataLoader
24
+ from tqdm import tqdm
25
+
26
+ from toolbox.torch.modules.loss import FocalLoss, HingeLoss, HingeLinear
27
+ from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
28
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
29
+ from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
30
+ from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveEncoder, ClsHead, WaveClassifier
31
+
32
+
33
+ def get_args():
34
+ parser = argparse.ArgumentParser()
35
+ parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
36
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
37
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
38
+
39
+ parser.add_argument("--country", default="en-US", type=str)
40
+ parser.add_argument("--shared_encoder", default="file_dir/global_model/best.bin", type=str)
41
+
42
+ parser.add_argument("--max_epochs", default=100, type=int)
43
+ parser.add_argument("--batch_size", default=64, type=int)
44
+ parser.add_argument("--learning_rate", default=1e-3, type=float)
45
+ parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
46
+ parser.add_argument("--patience", default=5, type=int)
47
+ parser.add_argument("--serialization_dir", default="country_models", type=str)
48
+ parser.add_argument("--seed", default=0, type=int)
49
+
50
+ args = parser.parse_args()
51
+ return args
52
+
53
+
54
+ def logging_config(file_dir: str):
55
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
56
+
57
+ logging.basicConfig(format=fmt,
58
+ datefmt="%m/%d/%Y %H:%M:%S",
59
+ level=logging.DEBUG)
60
+ file_handler = TimedRotatingFileHandler(
61
+ filename=os.path.join(file_dir, "main.log"),
62
+ encoding="utf-8",
63
+ when="D",
64
+ interval=1,
65
+ backupCount=7
66
+ )
67
+ file_handler.setLevel(logging.INFO)
68
+ file_handler.setFormatter(logging.Formatter(fmt))
69
+ logger = logging.getLogger(__name__)
70
+ logger.addHandler(file_handler)
71
+
72
+ return logger
73
+
74
+
75
+ class CollateFunction(object):
76
+ def __init__(self):
77
+ pass
78
+
79
+ def __call__(self, batch: List[dict]):
80
+ array_list = list()
81
+ label_list = list()
82
+ for sample in batch:
83
+ array = sample['waveform']
84
+ label = sample['label']
85
+
86
+ array_list.append(array)
87
+ label_list.append(label)
88
+
89
+ array_list = torch.stack(array_list)
90
+ label_list = torch.stack(label_list)
91
+ return array_list, label_list
92
+
93
+
94
+ collate_fn = CollateFunction()
95
+
96
+
97
+ def main():
98
+ args = get_args()
99
+
100
+ serialization_dir = Path(args.serialization_dir)
101
+ serialization_dir.mkdir(parents=True, exist_ok=True)
102
+
103
+ logger = logging_config(args.serialization_dir)
104
+
105
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
106
+ n_gpu = torch.cuda.device_count()
107
+ logger.info("GPU available: {}; device: {}".format(n_gpu, device))
108
+
109
+ vocabulary = Vocabulary.from_files(args.vocabulary_dir)
110
+
111
+ # datasets
112
+ logger.info("prepare datasets")
113
+ train_dataset = WaveClassifierExcelDataset(
114
+ vocab=vocabulary,
115
+ excel_file=args.train_dataset,
116
+ category=args.country,
117
+ category_field="category",
118
+ label_field="country_labels",
119
+ expected_sample_rate=8000,
120
+ max_wave_value=32768.0,
121
+ )
122
+ valid_dataset = WaveClassifierExcelDataset(
123
+ vocab=vocabulary,
124
+ excel_file=args.valid_dataset,
125
+ category=args.country,
126
+ category_field="category",
127
+ label_field="country_labels",
128
+ expected_sample_rate=8000,
129
+ max_wave_value=32768.0,
130
+ )
131
+
132
+ train_data_loader = DataLoader(
133
+ dataset=train_dataset,
134
+ batch_size=args.batch_size,
135
+ shuffle=True,
136
+ # On Linux, data loading can use multiple worker subprocesses; on Windows it cannot.
137
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count(),
138
+ collate_fn=collate_fn,
139
+ pin_memory=False,
140
+ # prefetch_factor=64,
141
+ )
142
+ valid_data_loader = DataLoader(
143
+ dataset=valid_dataset,
144
+ batch_size=args.batch_size,
145
+ shuffle=True,
146
+ # On Linux, data loading can use multiple worker subprocesses; on Windows it cannot.
147
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count(),
148
+ collate_fn=collate_fn,
149
+ pin_memory=False,
150
+ # prefetch_factor=64,
151
+ )
152
+
153
+ # models - classifier
154
+ wave_encoder = WaveEncoder(
155
+ conv1d_block_param_list=[
156
+ {
157
+ 'batch_norm': True,
158
+ 'in_channels': 80,
159
+ 'out_channels': 16,
160
+ 'kernel_size': 3,
161
+ 'stride': 3,
162
+ # 'padding': 'same',
163
+ 'activation': 'relu',
164
+ 'dropout': 0.1,
165
+ },
166
+ {
167
+ # 'batch_norm': True,
168
+ 'in_channels': 16,
169
+ 'out_channels': 16,
170
+ 'kernel_size': 3,
171
+ 'stride': 3,
172
+ # 'padding': 'same',
173
+ 'activation': 'relu',
174
+ 'dropout': 0.1,
175
+ },
176
+ {
177
+ # 'batch_norm': True,
178
+ 'in_channels': 16,
179
+ 'out_channels': 16,
180
+ 'kernel_size': 3,
181
+ 'stride': 3,
182
+ # 'padding': 'same',
183
+ 'activation': 'relu',
184
+ 'dropout': 0.1,
185
+ },
186
+ ],
187
+ mel_spectrogram_param={
188
+ "sample_rate": 8000,
189
+ "n_fft": 512,
190
+ "win_length": 200,
191
+ "hop_length": 80,
192
+ "f_min": 10,
193
+ "f_max": 3800,
194
+ "window_fn": "hamming",
195
+ "n_mels": 80,
196
+ }
197
+ )
198
+
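+ # Load the shared global-model checkpoint (the encoder trained in step_3) and keep only the "wave_encoder." weights, stripping the prefix so they fit the bare WaveEncoder.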
199
+ with open(args.shared_encoder, "rb") as f:
200
+ state_dict = torch.load(f, map_location=device)
201
+ processed_state_dict = dict()
202
+ prefix = "wave_encoder."
203
+ for k, v in state_dict.items():
204
+ if not str(k).startswith(prefix):
205
+ continue
206
+ k = k[len(prefix):]
207
+ processed_state_dict[k] = v
208
+
209
+ wave_encoder.load_state_dict(
210
+ state_dict=processed_state_dict,
211
+ strict=True,
212
+ )
213
+ cls_head = ClsHead(
214
+ input_dim=16,
215
+ num_layers=2,
216
+ hidden_dims=[32, 16],
217
+ activations="relu",
218
+ dropout=0.1,
219
+ num_labels=vocabulary.get_vocab_size(namespace="global_labels")
220
+ )
221
+ model = WaveClassifier(
222
+ wave_encoder=wave_encoder,
223
+ cls_head=cls_head,
224
+ )
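+ # Freeze the shared encoder; only the country-specific cls_head is updated by the optimizer below.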
225
+ model.wave_encoder.requires_grad_(requires_grad=False)
226
+ model.cls_head.requires_grad_(requires_grad=True)
227
+ model.to(device)
228
+
229
+ # optimizer
230
+ logger.info("prepare optimizer")
231
+ optimizer = torch.optim.Adam(
232
+ model.cls_head.parameters(),
233
+ lr=args.learning_rate,
234
+ )
235
+ lr_scheduler = torch.optim.lr_scheduler.StepLR(
236
+ optimizer,
237
+ step_size=2000
238
+ )
239
+ focal_loss = FocalLoss(
240
+ num_classes=vocabulary.get_vocab_size(namespace=args.country),
241
+ reduction="mean",
242
+ )
243
+ categorical_accuracy = CategoricalAccuracy()
244
+
245
+ # training loop
246
+ best_idx_epoch: int = None
247
+ best_accuracy: float = None
248
+ patience_count = 0
249
+ global_step = 0
250
+ model_filename_list = list()
251
+ for idx_epoch in range(args.max_epochs):
252
+
253
+ # training
254
+ model.train()
255
+ total_loss = 0
256
+ total_examples = 0
257
+ for step, batch in enumerate(tqdm(train_data_loader, desc="Epoch={} (training)".format(idx_epoch))):
258
+ input_ids, label_ids = batch
259
+ input_ids = input_ids.to(device)
260
+ label_ids: torch.LongTensor = label_ids.to(device).long()
261
+
262
+ logits = model.forward(input_ids)
263
+ loss = focal_loss.forward(logits, label_ids.view(-1))
264
+ categorical_accuracy(logits, label_ids)
265
+
266
+ total_loss += loss.item()
267
+ total_examples += input_ids.size(0)
268
+
269
+ optimizer.zero_grad()
270
+ loss.backward()
271
+ optimizer.step()
272
+ lr_scheduler.step()
273
+
274
+ global_step += 1
275
+ training_loss = total_loss / total_examples
276
+ training_loss = round(training_loss, 4)
277
+ training_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
278
+ training_accuracy = round(training_accuracy, 4)
279
+ logger.info("Epoch: {}; training_loss: {}; training_accuracy: {}".format(
280
+ idx_epoch, training_loss, training_accuracy
281
+ ))
282
+
283
+ # evaluation
284
+ model.eval()
285
+ total_loss = 0
286
+ total_examples = 0
287
+ for step, batch in enumerate(tqdm(valid_data_loader, desc="Epoch={} (evaluation)".format(idx_epoch))):
288
+ input_ids, label_ids = batch
289
+ input_ids = input_ids.to(device)
290
+ label_ids: torch.LongTensor = label_ids.to(device).long()
291
+
292
+ with torch.no_grad():
293
+ logits = model.forward(input_ids)
294
+ loss = focal_loss.forward(logits, label_ids.view(-1))
295
+ categorical_accuracy(logits, label_ids)
296
+
297
+ total_loss += loss.item()
298
+ total_examples += input_ids.size(0)
299
+
300
+ evaluation_loss = total_loss / total_examples
301
+ evaluation_loss = round(evaluation_loss, 4)
302
+ evaluation_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
303
+ evaluation_accuracy = round(evaluation_accuracy, 4)
304
+ logger.info("Epoch: {}; evaluation_loss: {}; evaluation_accuracy: {}".format(
305
+ idx_epoch, evaluation_loss, evaluation_accuracy
306
+ ))
307
+
308
+ # save metric
309
+ metrics = {
310
+ "training_loss": training_loss,
311
+ "training_accuracy": training_accuracy,
312
+ "evaluation_loss": evaluation_loss,
313
+ "evaluation_accuracy": evaluation_accuracy,
314
+ "best_idx_epoch": best_idx_epoch,
315
+ "best_accuracy": best_accuracy,
316
+ }
317
+ metrics_filename = os.path.join(args.serialization_dir, "metrics_epoch_{}.json".format(idx_epoch))
318
+ with open(metrics_filename, "w", encoding="utf-8") as f:
319
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
320
+
321
+ # save model
322
+ model_filename = os.path.join(args.serialization_dir, "model_epoch_{}.bin".format(idx_epoch))
323
+ model_filename_list.append(model_filename)
324
+ if len(model_filename_list) >= args.num_serialized_models_to_keep:
325
+ model_filename_to_delete = model_filename_list.pop(0)
326
+ os.remove(model_filename_to_delete)
327
+ torch.save(model.state_dict(), model_filename)
328
+
329
+ # early stop
330
+ best_model_filename = os.path.join(args.serialization_dir, "best.bin")
331
+ if best_accuracy is None:
332
+ best_idx_epoch = idx_epoch
333
+ best_accuracy = evaluation_accuracy
334
+ torch.save(model.state_dict(), best_model_filename)
335
+ elif evaluation_accuracy > best_accuracy:
336
+ best_idx_epoch = idx_epoch
337
+ best_accuracy = evaluation_accuracy
338
+ torch.save(model.state_dict(), best_model_filename)
339
+ patience_count = 0
340
+ elif patience_count >= args.patience:
341
+ break
342
+ else:
343
+ patience_count += 1
344
+
345
+ return
346
+
347
+
348
+ if __name__ == "__main__":
349
+ main()
examples/vm_sound_classification8/step_5_train_union.py ADDED
@@ -0,0 +1,499 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ from collections import defaultdict
5
+ import json
6
+ import logging
7
+ from logging.handlers import TimedRotatingFileHandler
8
+ import os
9
+ import platform
10
+ from pathlib import Path
11
+ import sys
12
+ import shutil
13
+ from typing import List
14
+
15
+ pwd = os.path.abspath(os.path.dirname(__file__))
16
+ sys.path.append(os.path.join(pwd, "../../"))
17
+
18
+ import pandas as pd
19
+ import torch
20
+ from torch.utils.data.dataloader import DataLoader
21
+ from tqdm import tqdm
22
+
23
+ from toolbox.torch.modules.loss import FocalLoss, HingeLoss, HingeLinear
24
+ from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
25
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
26
+ from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
27
+ from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveEncoder, ClsHead, WaveClassifier
28
+
29
+
30
+ def get_args():
31
+ parser = argparse.ArgumentParser()
32
+ parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
33
+
34
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
35
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
36
+
37
+ parser.add_argument("--max_steps", default=100000, type=int)
38
+ parser.add_argument("--save_steps", default=30, type=int)
39
+
40
+ parser.add_argument("--batch_size", default=1, type=int)
41
+ parser.add_argument("--learning_rate", default=1e-3, type=float)
42
+ parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
43
+ parser.add_argument("--patience", default=5, type=int)
44
+ parser.add_argument("--serialization_dir", default="union", type=str)
45
+ parser.add_argument("--seed", default=0, type=int)
46
+
47
+ parser.add_argument("--num_workers", default=0, type=int)
48
+
49
+ args = parser.parse_args()
50
+ return args
51
+
52
+
53
+ def logging_config(file_dir: str):
54
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
55
+
56
+ logging.basicConfig(format=fmt,
57
+ datefmt="%m/%d/%Y %H:%M:%S",
58
+ level=logging.DEBUG)
59
+ file_handler = TimedRotatingFileHandler(
60
+ filename=os.path.join(file_dir, "main.log"),
61
+ encoding="utf-8",
62
+ when="D",
63
+ interval=1,
64
+ backupCount=7
65
+ )
66
+ file_handler.setLevel(logging.INFO)
67
+ file_handler.setFormatter(logging.Formatter(fmt))
68
+ logger = logging.getLogger(__name__)
69
+ logger.addHandler(file_handler)
70
+
71
+ return logger
72
+
73
+
74
+ class CollateFunction(object):
75
+ def __init__(self):
76
+ pass
77
+
78
+ def __call__(self, batch: List[dict]):
79
+ array_list = list()
80
+ label_list = list()
81
+ for sample in batch:
82
+ array = sample['waveform']
83
+ label = sample['label']
84
+
85
+ array_list.append(array)
86
+ label_list.append(label)
87
+
88
+ array_list = torch.stack(array_list)
89
+ label_list = torch.stack(label_list)
90
+ return array_list, label_list
91
+
92
+
93
+ collate_fn = CollateFunction()
94
+
95
+
96
+ class DatasetIterator(object):
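+ # Cycles a DataLoader indefinitely: when the underlying iterator is exhausted, it is re-created.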
97
+ def __init__(self, data_loader: DataLoader):
98
+ self.data_loader = data_loader
99
+ self.data_loader_iter = iter(self.data_loader)
100
+
101
+ def next(self):
102
+ try:
103
+ result = self.data_loader_iter.__next__()
104
+ except StopIteration:
105
+ self.data_loader_iter = iter(self.data_loader)
106
+ result = self.data_loader_iter.__next__()
107
+ return result
108
+
109
+
110
+ def main():
111
+ args = get_args()
112
+
113
+ serialization_dir = Path(args.serialization_dir)
114
+ serialization_dir.mkdir(parents=True, exist_ok=True)
115
+
116
+ logger = logging_config(args.serialization_dir)
117
+
118
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
119
+ n_gpu = torch.cuda.device_count()
120
+ logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
121
+
122
+ vocabulary = Vocabulary.from_files(args.vocabulary_dir)
123
+ namespaces = vocabulary._token_to_index.keys()
124
+
125
+ # namespace_to_ratio
126
+ max_ratio = (len(namespaces) - 1) * 3
127
+ namespace_to_ratio = {n: 1 for n in namespaces}
128
+ namespace_to_ratio["global_labels"] = max_radio
129
+
130
+ # datasets
131
+ logger.info("prepare datasets")
132
+ namespace_to_datasets = dict()
133
+ for namespace in namespaces:
134
+ logger.info("prepare datasets - {}".format(namespace))
135
+ if namespace == "global_labels":
136
+ train_dataset = WaveClassifierExcelDataset(
137
+ vocab=vocabulary,
138
+ excel_file=args.train_dataset,
139
+ category=None,
140
+ category_field="category",
141
+ label_field="global_labels",
142
+ expected_sample_rate=8000,
143
+ max_wave_value=32768.0,
144
+ )
145
+ valid_dataset = WaveClassifierExcelDataset(
146
+ vocab=vocabulary,
147
+ excel_file=args.valid_dataset,
148
+ category=None,
149
+ category_field="category",
150
+ label_field="global_labels",
151
+ expected_sample_rate=8000,
152
+ max_wave_value=32768.0,
153
+ )
154
+ else:
155
+ train_dataset = WaveClassifierExcelDataset(
156
+ vocab=vocabulary,
157
+ excel_file=args.train_dataset,
158
+ category=namespace,
159
+ category_field="category",
160
+ label_field="country_labels",
161
+ expected_sample_rate=8000,
162
+ max_wave_value=32768.0,
163
+ )
164
+ valid_dataset = WaveClassifierExcelDataset(
165
+ vocab=vocabulary,
166
+ excel_file=args.valid_dataset,
167
+ category=namespace,
168
+ category_field="category",
169
+ label_field="country_labels",
170
+ expected_sample_rate=8000,
171
+ max_wave_value=32768.0,
172
+ )
173
+
174
+ train_data_loader = DataLoader(
175
+ dataset=train_dataset,
176
+ batch_size=args.batch_size,
177
+ shuffle=True,
178
+ # On Linux, data loading can use multiple worker subprocesses; on Windows it cannot.
179
+ # num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
180
+ num_workers=args.num_workers,
181
+ collate_fn=collate_fn,
182
+ pin_memory=False,
183
+ # prefetch_factor=64,
184
+ )
185
+ valid_data_loader = DataLoader(
186
+ dataset=valid_dataset,
187
+ batch_size=args.batch_size,
188
+ shuffle=True,
189
+ # On Linux, data loading can use multiple worker subprocesses; on Windows it cannot.
190
+ # num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
191
+ num_workers=args.num_workers,
192
+ collate_fn=collate_fn,
193
+ pin_memory=False,
194
+ # prefetch_factor=64,
195
+ )
196
+
197
+ namespace_to_datasets[namespace] = {
198
+ "train_data_loader": train_data_loader,
199
+ "valid_data_loader": valid_data_loader,
200
+ }
201
+
202
+ # datasets iterator
203
+ logger.info("prepare datasets iterator")
204
+ namespace_to_datasets_iter = dict()
205
+ for namespace in namespaces:
206
+ logger.info("prepare datasets iterator - {}".format(namespace))
207
+ train_data_loader = namespace_to_datasets[namespace]["train_data_loader"]
208
+ valid_data_loader = namespace_to_datasets[namespace]["valid_data_loader"]
209
+ namespace_to_datasets_iter[namespace] = {
210
+ "train_data_loader_iter": DatasetIterator(train_data_loader),
211
+ "valid_data_loader_iter": DatasetIterator(valid_data_loader),
212
+ }
213
+
214
+ # models - encoder
215
+ logger.info("prepare models - encoder")
216
+ wave_encoder = WaveEncoder(
217
+ conv2d_block_param_list=[
218
+ {
219
+ "batch_norm": True,
220
+ "in_channels": 1,
221
+ "out_channels": 4,
222
+ "kernel_size": 3,
223
+ "stride": 1,
224
+ # "padding": "same",
225
+ "dilation": 3,
226
+ "activation": "relu",
227
+ "dropout": 0.1,
228
+ },
229
+ {
230
+ # "batch_norm": True,
231
+ "in_channels": 4,
232
+ "out_channels": 4,
233
+ "kernel_size": 5,
234
+ "stride": 2,
235
+ # "padding": "same",
236
+ "dilation": 3,
237
+ "activation": "relu",
238
+ "dropout": 0.1,
239
+ },
240
+ {
241
+ # "batch_norm": True,
242
+ "in_channels": 4,
243
+ "out_channels": 4,
244
+ "kernel_size": 3,
245
+ "stride": 1,
246
+ # "padding": "same",
247
+ "dilation": 2,
248
+ "activation": "relu",
249
+ "dropout": 0.1,
250
+ },
251
+ ],
252
+ mel_spectrogram_param={
253
+ 'sample_rate': 8000,
254
+ 'n_fft': 512,
255
+ 'win_length': 200,
256
+ 'hop_length': 80,
257
+ 'f_min': 10,
258
+ 'f_max': 3800,
259
+ 'window_fn': 'hamming',
260
+ 'n_mels': 80,
261
+ }
262
+ )
263
+
264
+ # models - cls_head
265
+ logger.info("prepare models - cls_head")
266
+ namespace_to_cls_heads = dict()
267
+ for namespace in namespaces:
268
+ logger.info("prepare models - cls_head - {}".format(namespace))
269
+ cls_head = ClsHead(
270
+ input_dim=352,
271
+ num_layers=2,
272
+ hidden_dims=[128, 32],
273
+ activations="relu",
274
+ dropout=0.1,
275
+ num_labels=vocabulary.get_vocab_size(namespace=namespace)
276
+ )
277
+ namespace_to_cls_heads[namespace] = cls_head
278
+
279
+ # models - classifier
280
+ logger.info("prepare models - classifier")
281
+ namespace_to_classifier = dict()
282
+ for namespace in namespaces:
283
+ logger.info("prepare models - classifier - {}".format(namespace))
284
+ cls_head = namespace_to_cls_heads[namespace]
285
+ wave_classifier = WaveClassifier(
286
+ wave_encoder=wave_encoder,
287
+ cls_head=cls_head,
288
+ )
289
+ wave_classifier.to(device)
290
+ namespace_to_classifier[namespace] = wave_classifier
291
+
292
+ # optimizer
293
+ logger.info("prepare optimizer")
294
+ param_optimizer = list()
295
+ param_optimizer.extend(wave_encoder.parameters())
296
+ for _, cls_head in namespace_to_cls_heads.items():
297
+ param_optimizer.extend(cls_head.parameters())
298
+
299
+ optimizer = torch.optim.Adam(
300
+ param_optimizer,
301
+ lr=args.learning_rate,
302
+ )
303
+ lr_scheduler = torch.optim.lr_scheduler.StepLR(
304
+ optimizer,
305
+ step_size=10000
306
+ )
307
+ focal_loss = FocalLoss(
308
+ num_classes=vocabulary.get_vocab_size(namespace="global_labels"),
309
+ reduction="mean",
310
+ )
311
+
312
+ # categorical_accuracy
313
+ logger.info("prepare categorical_accuracy")
314
+ namespace_to_categorical_accuracy = dict()
315
+ for namespace in namespaces:
316
+ categorical_accuracy = CategoricalAccuracy()
317
+ namespace_to_categorical_accuracy[namespace] = categorical_accuracy
318
+
319
+ # training loop
320
+ logger.info("prepare training loop")
321
+
322
+ model_list = list()
323
+ best_idx_step = None
324
+ best_accuracy = None
325
+ patience_count = 0
326
+
327
+ namespace_to_total_loss = defaultdict(float)
328
+ namespace_to_total_examples = defaultdict(int)
329
+ for idx_step in tqdm(range(args.max_steps)):
330
+
331
+ # training one step
332
+ loss: torch.Tensor = None
333
+ for namespace in namespaces:
334
+ train_data_loader_iter = namespace_to_datasets_iter[namespace]["train_data_loader_iter"]
335
+
336
+ ratio = namespace_to_ratio[namespace]
337
+ model = namespace_to_classifier[namespace]
338
+ categorical_accuracy = namespace_to_categorical_accuracy[namespace]
339
+
340
+ model.train()
341
+
342
+ for _ in range(ratio):
343
+ batch = train_data_loader_iter.next()
344
+ input_ids, label_ids = batch
345
+ input_ids = input_ids.to(device)
346
+ label_ids: torch.LongTensor = label_ids.to(device).long()
347
+
348
+ logits = model.forward(input_ids)
349
+ task_loss = focal_loss.forward(logits, label_ids.view(-1))
350
+ categorical_accuracy(logits, label_ids)
351
+
352
+ if loss is None:
353
+ loss = task_loss
354
+ else:
355
+ loss += task_loss
356
+
357
+ namespace_to_total_loss[namespace] += task_loss.item()
358
+ namespace_to_total_examples[namespace] += input_ids.size(0)
359
+
360
+ optimizer.zero_grad()
361
+ loss.backward()
362
+ optimizer.step()
363
+ lr_scheduler.step()
364
+
365
+ # logging
366
+ if (idx_step + 1) % args.save_steps == 0:
367
+ metrics = dict()
368
+
369
+ # training
370
+ for namespace in namespaces:
371
+ total_loss = namespace_to_total_loss[namespace]
372
+ total_examples = namespace_to_total_examples[namespace]
373
+
374
+ training_loss = total_loss / total_examples
375
+ training_loss = round(training_loss, 4)
376
+
377
+ categorical_accuracy = namespace_to_categorical_accuracy[namespace]
378
+
379
+ training_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
380
+ training_accuracy = round(training_accuracy, 4)
381
+ logger.info("Step: {}; namespace: {}; training_loss: {}; training_accuracy: {}".format(
382
+ idx_step, namespace, training_loss, training_accuracy
383
+ ))
384
+ metrics[namespace] = {
385
+ "training_loss": training_loss,
386
+ "training_accuracy": training_accuracy,
387
+ }
388
+ namespace_to_total_loss = defaultdict(float)
389
+ namespace_to_total_examples = defaultdict(int)
390
+
391
+ # evaluation
392
+ for namespace in namespaces:
393
+ valid_data_loader = namespace_to_datasets[namespace]["valid_data_loader"]
394
+
395
+ model = namespace_to_classifier[namespace]
396
+ categorical_accuracy = namespace_to_categorical_accuracy[namespace]
397
+
398
+ model.eval()
399
+
400
+ total_loss = 0
401
+ total_examples = 0
402
+ for step, batch in enumerate(valid_data_loader):
403
+ input_ids, label_ids = batch
404
+ input_ids = input_ids.to(device)
405
+ label_ids: torch.LongTensor = label_ids.to(device).long()
406
+
407
+ with torch.no_grad():
408
+ logits = model.forward(input_ids)
409
+ loss = focal_loss.forward(logits, label_ids.view(-1))
410
+ categorical_accuracy(logits, label_ids)
411
+
412
+ total_loss += loss.item()
413
+ total_examples += input_ids.size(0)
414
+
415
+ evaluation_loss = total_loss / total_examples
416
+ evaluation_loss = round(evaluation_loss, 4)
417
+ evaluation_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
418
+ evaluation_accuracy = round(evaluation_accuracy, 4)
419
+ logger.info("Step: {}; namespace: {}; evaluation_loss: {}; evaluation_accuracy: {}".format(
420
+ idx_step, namespace, evaluation_loss, evaluation_accuracy
421
+ ))
422
+ metrics[namespace].update({
423
+ "evaluation_loss": evaluation_loss,
424
+ "evaluation_accuracy": evaluation_accuracy,
425
+ })
426
+
427
+ # update ratio
428
+ min_accuracy = min([m["evaluation_accuracy"] for m in metrics.values()])
429
+ max_accuracy = max([m["evaluation_accuracy"] for m in metrics.values()])
430
+ width = max(max_accuracy - min_accuracy, 1e-8)  # avoid division by zero when all namespaces have the same accuracy
431
+ for namespace, metric in metrics.items():
432
+ evaluation_accuracy = metric["evaluation_accuracy"]
433
+ ratio = (max_accuracy - evaluation_accuracy) / width * max_ratio
434
+ ratio = int(ratio)
435
+ namespace_to_ratio[namespace] = ratio
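+ # Namespaces with lower evaluation accuracy are sampled more heavily in later steps; the best-performing namespace can drop to a ratio of 0 and is then skipped until the next re-balance.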
436
+
437
+ msg = "".join(["{}: {}; ".format(k, v) for k, v in namespace_to_ratio.items()])
438
+ logger.info("namespace to ratio: {}".format(msg))
439
+
440
+ # save path
441
+ step_dir = serialization_dir / "step-{}".format(idx_step)
442
+ step_dir.mkdir(parents=True, exist_ok=False)
443
+
444
+ # save models
445
+ wave_encoder_filename = step_dir / "wave_encoder.pt"
446
+ torch.save(wave_encoder.state_dict(), wave_encoder_filename)
447
+ for namespace in namespaces:
448
+ cls_head_filename = step_dir / "{}.pt".format(namespace)
449
+ cls_head = namespace_to_cls_heads[namespace]
450
+ torch.save(cls_head.state_dict(), cls_head_filename)
451
+
452
+ model_list.append(step_dir)
453
+ if len(model_list) >= args.num_serialized_models_to_keep:
454
+ model_to_delete: Path = model_list.pop(0)
455
+ shutil.rmtree(model_to_delete.as_posix())
456
+
457
+ # save metric
458
+ this_accuracy = metrics["global_labels"]["evaluation_accuracy"]
459
+ if best_accuracy is None:
460
+ best_idx_step = idx_step
461
+ best_accuracy = this_accuracy
462
+ elif metrics["global_labels"]["evaluation_accuracy"] > best_accuracy:
463
+ best_idx_step = idx_step
464
+ best_accuracy = this_accuracy
465
+ else:
466
+ pass
467
+
468
+ metrics_filename = step_dir / "metrics_epoch.json"
469
+ metrics.update({
470
+ "idx_step": idx_step,
471
+ "best_idx_step": best_idx_step,
472
+ })
473
+ with open(metrics_filename, "w", encoding="utf-8") as f:
474
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
475
+
476
+ # save best
477
+ best_dir = serialization_dir / "best"
478
+ if best_idx_step == idx_step:
479
+ if best_dir.exists():
480
+ shutil.rmtree(best_dir)
481
+ shutil.copytree(step_dir, best_dir)
482
+
483
+ # early stop
484
+ early_stop_flag = False
485
+ if best_idx_step == idx_step:
486
+ patience_count = 0
487
+ else:
488
+ patience_count += 1
489
+ if patience_count >= args.patience:
490
+ early_stop_flag = True
491
+
492
+ # early stop
493
+ if early_stop_flag:
494
+ break
495
+ return
496
+
497
+
498
+ if __name__ == "__main__":
499
+ main()
examples/vm_sound_classification8/stop.sh ADDED
@@ -0,0 +1,3 @@
1
+ #!/usr/bin/env bash
2
+
3
+ kill -9 `ps -aef | grep 'vm_sound_classification/bin/python3' | grep -v grep | awk '{print $2}' | sed 's/\n/ /'`
install.sh ADDED
@@ -0,0 +1,64 @@
1
+ #!/usr/bin/env bash
2
+
3
+ # bash install.sh --stage 2 --stop_stage 2 --system_version centos
4
+
5
+
6
+ python_version=3.8.10
7
+ system_version="centos";
8
+
9
+ verbose=true;
10
+ stage=-1
11
+ stop_stage=0
12
+
13
+
14
+ # parse options
15
+ while true; do
16
+ [ -z "${1:-}" ] && break; # break if there are no arguments
17
+ case "$1" in
18
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
19
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
20
+ old_value="(eval echo \\$$name)";
21
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
22
+ was_bool=true;
23
+ else
24
+ was_bool=false;
25
+ fi
26
+
27
+ # Set the variable to the right value-- the escaped quotes make it work if
28
+ # the option had spaces, like --cmd "queue.pl -sync y"
29
+ eval "${name}=\"$2\"";
30
+
31
+ # Check that Boolean-valued arguments are really Boolean.
32
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
33
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
34
+ exit 1;
35
+ fi
36
+ shift 2;
37
+ ;;
38
+
39
+ *) break;
40
+ esac
41
+ done
42
+
43
+ work_dir="$(pwd)"
44
+
45
+
46
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
47
+ $verbose && echo "stage 1: install python"
48
+ cd "${work_dir}" || exit 1;
49
+
50
+ sh ./script/install_python.sh --python_version "${python_version}" --system_version "${system_version}"
51
+ fi
52
+
53
+
54
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
55
+ $verbose && echo "stage 2: create virtualenv"
56
+
57
+ # /usr/local/python-3.6.5/bin/virtualenv vm_sound_classification
58
+ # source /data/local/bin/vm_sound_classification/bin/activate
59
+ /usr/local/python-${python_version}/bin/pip3 install virtualenv
60
+ mkdir -p /data/local/bin
61
+ cd /data/local/bin || exit 1;
62
+ /usr/local/python-${python_version}/bin/virtualenv vm_sound_classification
63
+
64
+ fi
main.py ADDED
@@ -0,0 +1,172 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ from functools import lru_cache
5
+ import json
6
+ from pathlib import Path
7
+ import platform
8
+ import shutil
9
+ import tempfile
10
+ import zipfile
11
+
12
+ import gradio as gr
13
+ import numpy as np
14
+ import torch
15
+
16
+ from project_settings import environment, project_path
17
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
18
+
19
+
20
+ def get_args():
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument(
23
+ "--examples_dir",
24
+ default=(project_path / "data/examples").as_posix(),
25
+ type=str
26
+ )
27
+ parser.add_argument(
28
+ "--trained_model_dir",
29
+ default=(project_path / "trained_models").as_posix(),
30
+ type=str
31
+ )
32
+ parser.add_argument(
33
+ "--server_port",
34
+ default=environment.get("server_port", 7860),
35
+ type=int
36
+ )
37
+ args = parser.parse_args()
38
+ return args
39
+
40
+
41
+ @lru_cache(maxsize=100)
42
+ def load_model(model_file: Path):
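+ # The model zip contains a TorchScript trace ("trace_model.zip") and a "vocabulary" directory; both are unpacked to a temp dir, loaded, and then cleaned up. Results are cached via lru_cache.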
43
+ with zipfile.ZipFile(model_file, "r") as f_zip:
44
+ out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
45
+ if out_root.exists():
46
+ shutil.rmtree(out_root.as_posix())
47
+ out_root.mkdir(parents=True, exist_ok=True)
48
+ f_zip.extractall(path=out_root)
49
+
50
+ tgt_path = out_root / model_file.stem
51
+ jit_model_file = tgt_path / "trace_model.zip"
52
+ vocab_path = tgt_path / "vocabulary"
53
+
54
+ vocabulary = Vocabulary.from_files(vocab_path.as_posix())
55
+
56
+ with open(jit_model_file.as_posix(), "rb") as f:
57
+ model = torch.jit.load(f)
58
+ model.eval()
59
+
60
+ shutil.rmtree(tgt_path)
61
+
62
+ d = {
63
+ "model": model,
64
+ "vocabulary": vocabulary
65
+ }
66
+ return d
67
+
68
+
69
+ def click_button(audio: np.ndarray,
70
+ model_name: str,
71
+ ground_true: str) -> tuple:
72
+
73
+ sample_rate, signal = audio
74
+
75
+ model_file = "trained_models/{}.zip".format(model_name)
76
+ model_file = Path(model_file)
77
+ d = load_model(model_file)
78
+
79
+ model = d["model"]
80
+ vocabulary = d["vocabulary"]
81
+
82
+ inputs = signal / (1 << 15)
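+ # normalize int16 PCM samples (full scale 32768) to roughly [-1, 1)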
83
+ inputs = torch.tensor(inputs, dtype=torch.float32)
84
+ inputs = torch.unsqueeze(inputs, dim=0)
85
+
86
+ with torch.no_grad():
87
+ logits = model.forward(inputs)
88
+ probs = torch.nn.functional.softmax(logits, dim=-1)
89
+ label_idx = torch.argmax(probs, dim=-1)
90
+
91
+ label_idx = label_idx.cpu()
92
+ probs = probs.cpu()
93
+
94
+ label_idx = label_idx.numpy()[0]
95
+ prob = probs.numpy()[0][label_idx]
96
+
97
+ label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
98
+
99
+ return label_str, round(prob, 4)
100
+
101
+
102
+ def main():
103
+ args = get_args()
104
+
105
+ examples_dir = Path(args.examples_dir)
106
+ trained_model_dir = Path(args.trained_model_dir)
107
+
108
+ # examples
109
+ examples = list()
110
+ for filename in examples_dir.glob("*/*/*.wav"):
111
+ language = filename.parts[-3]
112
+ label = filename.parts[-2]
113
+
114
+ examples.append([
115
+ filename.as_posix(),
116
+ language,
117
+ label
118
+ ])
119
+
120
+ # models
121
+ model_choices = list()
122
+ for filename in trained_model_dir.glob("*.zip"):
123
+ model_name = filename.stem
124
+ model_choices.append(model_name)
125
+
126
+ # ui
127
+ brief_description = """
128
+ International intelligent outbound-calling system: telephone sound classification.
129
+ """
130
+
131
+ # ui
132
+ with gr.Blocks() as blocks:
133
+ gr.Markdown(value=brief_description)
134
+
135
+ with gr.Row():
136
+ with gr.Column(scale=3):
137
+ c_audio = gr.Audio(label="audio")
138
+ with gr.Row():
139
+ with gr.Column(scale=3):
140
+ c_model_name = gr.Dropdown(choices=model_choices, value=model_choices[0], label="language")
141
+ with gr.Column(scale=3):
142
+ c_ground_true = gr.Textbox(label="ground_true")
143
+
144
+ c_button = gr.Button("run", variant="primary")
145
+ with gr.Column(scale=3):
146
+ c_label = gr.Textbox(label="label")
147
+ c_probability = gr.Number(label="probability")
148
+
149
+ gr.Examples(
150
+ examples,
151
+ inputs=[c_audio, c_model_name, c_ground_true],
152
+ outputs=[c_label, c_probability],
153
+ fn=click_button,
154
+ examples_per_page=5,
155
+ )
156
+
157
+ c_button.click(
158
+ click_button,
159
+ inputs=[c_audio, c_model_name, c_ground_true],
160
+ outputs=[c_label, c_probability],
161
+ )
162
+
163
+ blocks.queue().launch(
164
+ share=False,
165
+ server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
166
+ server_port=args.server_port
167
+ )
168
+ return
169
+
170
+
171
+ if __name__ == "__main__":
172
+ main()
project_settings.py ADDED
@@ -0,0 +1,19 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+ from pathlib import Path
5
+
6
+ from toolbox.os.environment import EnvironmentManager
7
+
8
+
9
+ project_path = os.path.abspath(os.path.dirname(__file__))
10
+ project_path = Path(project_path)
11
+
12
+ environment = EnvironmentManager(
13
+ path=os.path.join(project_path, "dotenv"),
14
+ env=os.environ.get("environment", "dev"),
15
+ )
16
+
17
+
18
+ if __name__ == '__main__':
19
+ pass
requirements.txt ADDED
@@ -0,0 +1,12 @@
1
+ torch==2.3.0
2
+ torchaudio==2.3.0
3
+ fsspec==2024.5.0
4
+ librosa==0.10.2
5
+ pandas==2.0.3
6
+ openpyxl==3.0.9
7
+ xlrd==1.2.0
8
+ tqdm==4.66.4
9
+ overrides==1.9.0
10
+ pyyaml==6.0.1
11
+ evaluate==0.4.2
12
+ gradio
script/install_nvidia_driver.sh ADDED
@@ -0,0 +1,184 @@
1
+ #!/usr/bin/env bash
2
+ # Installing the GPU driver requires disabling the existing display driver first, rebooting the machine, and then installing.
3
+ # Reference links:
4
+ #https://blog.csdn.net/kingschan/article/details/19033595
5
+ #https://blog.csdn.net/HaixWang/article/details/90408538
6
+ #
7
+ #>>> yum install -y pciutils
8
+ # Check whether the Linux machine has a GPU
9
+ #lspci |grep -i nvidia
10
+ #
11
+ #>>> lspci |grep -i nvidia
12
+ #00:08.0 3D controller: NVIDIA Corporation TU104GL [Tesla T4] (rev a1)
13
+ #
14
+ #
15
+ # NVIDIA driver download
16
+ # First check on the PyTorch site which CUDA version is needed, then install the matching cuda-toolkit.
17
+ # Then download and install the NVIDIA driver that matches the GPU model
18
+ #
19
+ ## PyTorch version
20
+ #https://pytorch.org/get-started/locally/
21
+ #
22
+ ## CUDA download (this is probably not needed)
23
+ #https://developer.nvidia.com/cuda-toolkit-archive
24
+ #
25
+ ## NVIDIA driver
26
+ #https://www.nvidia.cn/Download/index.aspx?lang=cn
27
+ #http://www.nvidia.com/Download/index.aspx
28
+ #
29
+ # Select from the dropdown lists below to determine the appropriate driver for your NVIDIA product.
30
+ # Product type:
31
+ #Data Center / Tesla
32
+ # Product series:
33
+ #T-Series
34
+ # Product family:
35
+ #Tesla T4
36
+ # Operating system:
37
+ #Linux 64-bit
38
+ #CUDA Toolkit:
39
+ #10.2
40
+ # Language:
41
+ #Chinese (Simplified)
42
+ #
43
+ #
44
+ #>>> mkdir -p /data/tianxing
45
+ #>>> cd /data/tianxing
46
+ #>>> wget https://cn.download.nvidia.com/tesla/440.118.02/NVIDIA-Linux-x86_64-440.118.02.run
47
+ #>>> sh NVIDIA-Linux-x86_64-440.118.02.run
48
+ #
49
+ ## Error:
50
+ #ERROR: The Nouveau kernel driver is currently in use by your system. This driver is incompatible with the NVIDIA driver, and must be disabled before proceeding. Please consult the NVIDIA driver README and your
51
+ #Linux distribution's documentation for details on how to correctly disable the Nouveau kernel driver.
52
+ #[OK]
53
+ #
54
+ #For some distributions, Nouveau can be disabled by adding a file in the modprobe configuration directory. Would you like nvidia-installer to attempt to create this modprobe file for you?
55
+ #[NO]
56
+ #
57
+ #ERROR: Installation has failed. Please see the file '/var/log/nvidia-installer.log' for details. You may find suggestions on fixing installation problems in the README available on the Linux driver download
58
+ #page at www.nvidia.com.
59
+ #[OK]
60
+ #
61
+ ## Reference link:
62
+ #https://blog.csdn.net/kingschan/article/details/19033595
63
+ #
64
+ ## Disable the stock nouveau display driver
65
+ #>>> echo -e "blacklist nouveau\noptions nouveau modeset=0\n" > /etc/modprobe.d/blacklist-nouveau.conf
66
+ #>>> sudo dracut --force
67
+ ## Reboot
68
+ #>>> reboot
69
+ #
70
+ #>>> init 3
71
+ #>>> sh NVIDIA-Linux-x86_64-440.118.02.run
72
+ #
73
+ ## Error
74
+ #ERROR: Unable to find the kernel source tree for the currently running kernel. Please make sure you have installed the kernel source files for your kernel and that they are properly configured; on Red Hat Linux systems, for example, be sure you have the 'kernel-source' or 'kernel-devel' RPM installed. If you know the correct kernel source files are installed, you may specify the kernel source path with the '--kernel-source-path' command line option.
75
+ #[OK]
76
+ #ERROR: Installation has failed. Please see the file '/var/log/nvidia-installer.log' for details. You may find suggestions on fixing installation problems in the README available on the Linux driver download
77
+ #page at www.nvidia.com.
78
+ #[OK]
79
+ #
80
+ ## Reference link
81
+ ## https://blog.csdn.net/HaixWang/article/details/90408538
82
+ #
83
+ #>>> uname -r
84
+ #3.10.0-1160.49.1.el7.x86_64
85
+ #>>> yum install kernel-devel kernel-headers -y
86
+ #>>> yum info kernel-devel kernel-headers
87
+ #>>> yum install -y "kernel-devel-uname-r == $(uname -r)"
88
+ #>>> yum -y distro-sync
89
+ #
90
+ #>>> sh NVIDIA-Linux-x86_64-440.118.02.run
91
+ #
92
+ ## Installation succeeded
93
+ #WARNING: nvidia-installer was forced to guess the X library path '/usr/lib64' and X module path '/usr/lib64/xorg/modules'; these paths were not queryable from the system. If X fails to find the NVIDIA X driver
94
+ #module, please install the `pkg-config` utility and the X.Org SDK/development package for your distribution and reinstall the driver.
95
+ #[OK]
96
+ #Install NVIDIA's 32-bit compatibility libraries?
97
+ #[YES]
98
+ #Installation of the kernel module for the NVIDIA Accelerated Graphics Driver for Linux-x86_64 (version 440.118.02) is now complete.
99
+ #[OK]
100
+ #
101
+ #
102
+ ## Check GPU usage; "watch -n 1 -d nvidia-smi" refreshes once per second.
103
+ #>>> nvidia-smi
104
+ #Thu Mar 9 12:00:37 2023
105
+ #+-----------------------------------------------------------------------------+
106
+ #| NVIDIA-SMI 440.118.02 Driver Version: 440.118.02 CUDA Version: 10.2 |
107
+ #|-------------------------------+----------------------+----------------------+
108
+ #| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
109
+ #| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
110
+ #|===============================+======================+======================|
111
+ #| 0 Tesla T4 Off | 00000000:00:08.0 Off | Off |
112
+ #| N/A 54C P0 22W / 70W | 0MiB / 16127MiB | 0% Default |
113
+ #+-------------------------------+----------------------+----------------------+
114
+ #
115
+ #+-----------------------------------------------------------------------------+
116
+ #| Processes: GPU Memory |
117
+ #| GPU PID Type Process name Usage |
118
+ #|=============================================================================|
119
+ #| No running processes found |
120
+ #+-----------------------------------------------------------------------------+
121
+ #
122
+ #
123
+
124
+ # params
125
+ stage=1
126
+ nvidia_driver_filename=https://cn.download.nvidia.com/tesla/440.118.02/NVIDIA-Linux-x86_64-440.118.02.run
127
+
128
+ # parse options
129
+ while true; do
130
+ [ -z "${1:-}" ] && break; # break if there are no arguments
131
+ case "$1" in
132
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
133
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
134
+ old_value="(eval echo \\$$name)";
135
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
136
+ was_bool=true;
137
+ else
138
+ was_bool=false;
139
+ fi
140
+
141
+ # Set the variable to the right value-- the escaped quotes make it work if
142
+ # the option had spaces, like --cmd "queue.pl -sync y"
143
+ eval "${name}=\"$2\"";
144
+
145
+ # Check that Boolean-valued arguments are really Boolean.
146
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
147
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
148
+ exit 1;
149
+ fi
150
+ shift 2;
151
+ ;;
152
+
153
+ *) break;
154
+ esac
155
+ done
156
+
157
+ echo "stage: ${stage}";
158
+
159
+ yum -y install wget
160
+ yum -y install sudo
161
+
162
+ if [ ${stage} -eq 0 ]; then
163
+ mkdir -p /data/dep
164
+ cd /data/dep || echo 1;
165
+ wget -P /data/dep ${nvidia_driver_filename}
166
+
167
+ echo -e "blacklist nouveau\noptions nouveau modeset=0\n" > /etc/modprobe.d/blacklist-nouveau.conf
168
+ sudo dracut --force
169
+ # Reboot
170
+ reboot
171
+ elif [ ${stage} -eq 1 ]; then
172
+ init 3
173
+
174
+ yum install -y kernel-devel kernel-headers
175
+ yum info kernel-devel kernel-headers
176
+ yum install -y "kernel-devel-uname-r == $(uname -r)"
177
+ yum -y distro-sync
178
+
179
+ cd /data/dep || exit 1;
180
+
181
+ # During installation, press Enter three times to accept the prompts.
182
+ sh NVIDIA-Linux-x86_64-440.118.02.run
183
+ nvidia-smi
184
+ fi
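
Note: the staged workflow above is meant to be run twice; assuming the script is saved as script/install_nvidia_driver.sh, a typical sequence is "sh script/install_nvidia_driver.sh --stage 0" (download the driver, blacklist nouveau, and reboot), followed by "sh script/install_nvidia_driver.sh --stage 1" (install kernel headers, build the kernel module, and check nvidia-smi). The --stage flag is mapped onto the stage variable by the option-parsing loop above.
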
script/install_python.sh ADDED
@@ -0,0 +1,129 @@
1
+ #!/usr/bin/env bash
2
+
3
+ # parameters:
4
+ python_version="3.6.5";
5
+ system_version="centos";
6
+
7
+
8
+ # parse options
9
+ while true; do
10
+ [ -z "${1:-}" ] && break; # break if there are no arguments
11
+ case "$1" in
12
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
13
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
14
+ old_value="$(eval echo \$$name)";
15
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
16
+ was_bool=true;
17
+ else
18
+ was_bool=false;
19
+ fi
20
+
21
+ # Set the variable to the right value-- the escaped quotes make it work if
22
+ # the option had spaces, like --cmd "queue.pl -sync y"
23
+ eval "${name}=\"$2\"";
24
+
25
+ # Check that Boolean-valued arguments are really Boolean.
26
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
27
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
28
+ exit 1;
29
+ fi
30
+ shift 2;
31
+ ;;
32
+
33
+ *) break;
34
+ esac
35
+ done
36
+
37
+ echo "python_version: ${python_version}";
38
+ echo "system_version: ${system_version}";
39
+
40
+
41
+ if [ ${system_version} = "centos" ]; then
42
+ # install the Python build dependencies
43
+ yum -y groupinstall "Development tools"
44
+ yum -y install zlib-devel bzip2-devel openssl-devel ncurses-devel sqlite-devel readline-devel tk-devel gdbm-devel db4-devel libpcap-devel xz-devel
45
+ yum install libffi-devel -y
46
+ yum install -y wget
47
+ yum install -y make
48
+
49
+ mkdir -p /data/dep
50
+ cd /data/dep || exit 1;
51
+ if [ ! -e Python-${python_version}.tgz ]; then
52
+ wget -P /data/dep https://www.python.org/ftp/python/${python_version}/Python-${python_version}.tgz
53
+ fi
54
+
55
+ cd /data/dep || exit 1;
56
+ if [ ! -d Python-${python_version} ]; then
57
+ tar -zxvf Python-${python_version}.tgz
58
+ cd /data/dep/Python-${python_version} || exit 1;
59
+ fi
60
+
61
+ mkdir /usr/local/python-${python_version}
62
+ ./configure --prefix=/usr/local/python-${python_version}
63
+ make && make install
64
+
65
+ /usr/local/python-${python_version}/bin/python3 -V
66
+ /usr/local/python-${python_version}/bin/pip3 -V
67
+
68
+ rm -rf /usr/local/bin/python3
69
+ rm -rf /usr/local/bin/pip3
70
+ ln -s /usr/local/python-${python_version}/bin/python3 /usr/local/bin/python3
71
+ ln -s /usr/local/python-${python_version}/bin/pip3 /usr/local/bin/pip3
72
+
73
+ python3 -V
74
+ pip3 -V
75
+
76
+ elif [ ${system_version} = "ubuntu" ]; then
77
+ # install the Python build dependencies
78
+ # https://zhuanlan.zhihu.com/p/506491209
79
+
80
+ # refresh the package index
81
+ sudo apt update
82
+ # list the currently available upgrades
83
+ sudo apt list --upgradable
84
+ # if the previous step reports upgradable packages, apply the upgrades
85
+ sudo apt -y upgrade
86
+ # install the GCC compiler
87
+ sudo apt install gcc
88
+ # verify the installation
89
+ gcc -v
90
+
91
+ # install build dependencies
92
+ sudo apt install -y build-essential zlib1g-dev libncurses5-dev libgdbm-dev libnss3-dev libssl-dev libreadline-dev libffi-dev libbz2-dev liblzma-dev sqlite3 libsqlite3-dev tk-dev uuid-dev libgdbm-compat-dev
93
+
94
+ mkdir -p /data/dep
95
+ cd /data/dep || exit 1;
96
+ if [ ! -e Python-${python_version}.tgz ]; then
97
+ # sudo wget -P /data/dep https://www.python.org/ftp/python/3.6.5/Python-3.6.5.tgz
98
+ sudo wget -P /data/dep https://www.python.org/ftp/python/${python_version}/Python-${python_version}.tgz
99
+ fi
100
+
101
+ cd /data/dep || exit 1;
102
+ if [ ! -d Python-${python_version} ]; then
103
+ # tar -zxvf Python-3.6.5.tgz
104
+ tar -zxvf Python-${python_version}.tgz
105
+ # cd /data/dep/Python-3.6.5
106
+ cd /data/dep/Python-${python_version} || exit 1;
107
+ fi
108
+
109
+ # mkdir /usr/local/python-3.6.5
110
+ mkdir /usr/local/python-${python_version}
111
+
112
+ # check dependencies and configure the build
113
+ # sudo ./configure --prefix=/usr/local/python-3.6.5 --enable-optimizations --with-lto --enable-shared
114
+ sudo ./configure --prefix=/usr/local/python-${python_version} --enable-optimizations --with-lto --enable-shared
115
+ cpu_count=$(cat /proc/cpuinfo | grep processor | wc -l)
116
+ # sudo make -j 4
117
+ sudo make -j "${cpu_count}"
+ sudo make install
118
+
119
+ /usr/local/python-${python_version}/bin/python3 -V
120
+ /usr/local/python-${python_version}/bin/pip3 -V
121
+
122
+ rm -rf /usr/local/bin/python3
123
+ rm -rf /usr/local/bin/pip3
124
+ ln -s /usr/local/python-${python_version}/bin/python3 /usr/local/bin/python3
125
+ ln -s /usr/local/python-${python_version}/bin/pip3 /usr/local/bin/pip3
126
+
127
+ python3 -V
128
+ pip3 -V
129
+ fi
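
Note: the option parser above rewrites "--python-version" and "--system-version" into the python_version and system_version variables, so an invocation such as "sh script/install_python.sh --python-version 3.8.10 --system-version ubuntu" (versions here are illustrative) builds that interpreter from source and symlinks it as python3 / pip3.
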
toolbox/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ if __name__ == '__main__':
5
+ pass
toolbox/json/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/json/misc.py ADDED
@@ -0,0 +1,63 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import Callable
4
+
5
+
6
+ def traverse(js, callback: Callable, *args, **kwargs):
7
+ if isinstance(js, list):
8
+ result = list()
9
+ for l in js:
10
+ l = traverse(l, callback, *args, **kwargs)
11
+ result.append(l)
12
+ return result
13
+ elif isinstance(js, tuple):
14
+ result = list()
15
+ for l in js:
16
+ l = traverse(l, callback, *args, **kwargs)
17
+ result.append(l)
18
+ return tuple(result)
19
+ elif isinstance(js, dict):
20
+ result = dict()
21
+ for k, v in js.items():
22
+ k = traverse(k, callback, *args, **kwargs)
23
+ v = traverse(v, callback, *args, **kwargs)
24
+ result[k] = v
25
+ return result
26
+ elif isinstance(js, int):
27
+ return callback(js, *args, **kwargs)
28
+ elif isinstance(js, str):
29
+ return callback(js, *args, **kwargs)
30
+ else:
31
+ return js
32
+
33
+
34
+ def demo1():
35
+ d = {
36
+ "env": "ppe",
37
+ "mysql_connect": {
38
+ "host": "$mysql_connect_host",
39
+ "port": 3306,
40
+ "user": "callbot",
41
+ "password": "NxcloudAI2021!",
42
+ "database": "callbot_ppe",
43
+ "charset": "utf8"
44
+ },
45
+ "es_connect": {
46
+ "hosts": ["10.20.251.8"],
47
+ "http_auth": ["elastic", "ElasticAI2021!"],
48
+ "port": 9200
49
+ }
50
+ }
51
+
52
+ def callback(s):
53
+ if isinstance(s, str) and s.startswith('$'):
54
+ return s[1:]
55
+ return s
56
+
57
+ result = traverse(d, callback=callback)
58
+ print(result)
59
+ return
60
+
61
+
62
+ if __name__ == '__main__':
63
+ demo1()
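
The traverse helper above also forwards extra positional and keyword arguments to the callback, which demo1 does not exercise. A minimal sketch, with an illustrative callback and mapping that are not part of the module:

    from toolbox.json.misc import traverse

    def resolve(value, mapping):
        # replace "$key" strings using the supplied mapping, leave everything else untouched
        if isinstance(value, str) and value.startswith("$"):
            return mapping.get(value[1:], value)
        return value

    config = {"host": "$mysql_host", "port": 3306}
    print(traverse(config, callback=resolve, mapping={"mysql_host": "127.0.0.1"}))
    # {'host': '127.0.0.1', 'port': 3306}
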
toolbox/os/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/os/command.py ADDED
@@ -0,0 +1,59 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+
5
+
6
+ class Command(object):
7
+ custom_command = [
8
+ "cd"
9
+ ]
10
+
11
+ @staticmethod
12
+ def _get_cmd(command):
13
+ command = str(command).strip()
14
+ if command == "":
15
+ return None
16
+ cmd_and_args = command.split(sep=" ")
17
+ cmd = cmd_and_args[0]
18
+ args = " ".join(cmd_and_args[1:])
19
+ return cmd, args
20
+
21
+ @classmethod
22
+ def popen(cls, command):
23
+ cmd, args = cls._get_cmd(command)
24
+ if cmd in cls.custom_command:
25
+ method = getattr(cls, cmd)
26
+ return method(args)
27
+ else:
28
+ resp = os.popen(command)
29
+ result = resp.read()
30
+ resp.close()
31
+ return result
32
+
33
+ @classmethod
34
+ def cd(cls, args):
35
+ if args.startswith("/"):
36
+ os.chdir(args)
37
+ else:
38
+ pwd = os.getcwd()
39
+ path = os.path.join(pwd, args)
40
+ os.chdir(path)
41
+
42
+ @classmethod
43
+ def system(cls, command):
44
+ return os.system(command)
45
+
46
+ def __init__(self):
47
+ pass
48
+
49
+
50
+ def ps_ef_grep(keyword: str):
51
+ cmd = "ps -ef | grep {}".format(keyword)
52
+ rows = Command.popen(cmd)
53
+ rows = str(rows).split("\n")
54
+ rows = [row for row in rows if row.__contains__(keyword) and not row.__contains__("grep")]
55
+ return rows
56
+
57
+
58
+ if __name__ == "__main__":
59
+ pass
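
A short usage sketch of the module above: "cd" is intercepted and changes the Python process's working directory, while every other command runs through os.popen, so a later command inherits the new directory:

    from toolbox.os.command import Command, ps_ef_grep

    Command.popen("cd /tmp")        # handled by Command.cd, calls os.chdir
    print(Command.popen("pwd"))     # runs in a subshell, prints /tmp
    print(ps_ef_grep("python3"))    # rows of `ps -ef` matching python3, grep itself excluded
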
toolbox/os/environment.py ADDED
@@ -0,0 +1,114 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import json
4
+ import os
5
+
6
+ from dotenv import load_dotenv
7
+ from dotenv.main import DotEnv
8
+
9
+ from toolbox.json.misc import traverse
10
+
11
+
12
+ class EnvironmentManager(object):
13
+ def __init__(self, path, env, override=False):
14
+ filename = os.path.join(path, '{}.env'.format(env))
15
+ self.filename = filename
16
+
17
+ load_dotenv(
18
+ dotenv_path=filename,
19
+ override=override
20
+ )
21
+
22
+ self._environ = dict()
23
+
24
+ def open_dotenv(self, filename: str = None):
25
+ filename = filename or self.filename
26
+ dotenv = DotEnv(
27
+ dotenv_path=filename,
28
+ stream=None,
29
+ verbose=False,
30
+ interpolate=False,
31
+ override=False,
32
+ encoding="utf-8",
33
+ )
34
+ result = dotenv.dict()
35
+ return result
36
+
37
+ def get(self, key, default=None, dtype=str):
38
+ result = os.environ.get(key)
39
+ if result is None:
40
+ if default is None:
41
+ result = None
42
+ else:
43
+ result = default
44
+ else:
45
+ result = dtype(result)
46
+ self._environ[key] = result
47
+ return result
48
+
49
+
50
+ _DEFAULT_DTYPE_MAP = {
51
+ 'int': int,
52
+ 'float': float,
53
+ 'str': str,
54
+ 'json.loads': json.loads
55
+ }
56
+
57
+
58
+ class JsonConfig(object):
59
+ """
60
+ 将 json 中, 形如 `$float:threshold` 的值, 处理为:
61
+ 从环境变量中查到 threshold, 再将其转换为 float 类型.
62
+ """
63
+ def __init__(self, dtype_map: dict = None, environment: EnvironmentManager = None):
64
+ self.dtype_map = dtype_map or _DEFAULT_DTYPE_MAP
65
+ self.environment = environment or os.environ
66
+
67
+ def sanitize_by_filename(self, filename: str):
68
+ with open(filename, 'r', encoding='utf-8') as f:
69
+ js = json.load(f)
70
+
71
+ return self.sanitize_by_json(js)
72
+
73
+ def sanitize_by_json(self, js):
74
+ js = traverse(
75
+ js,
76
+ callback=self.sanitize,
77
+ environment=self.environment
78
+ )
79
+ return js
80
+
81
+ def sanitize(self, string, environment):
82
+ """支持 $ 符开始的, 环境变量配置"""
83
+ if isinstance(string, str) and string.startswith('$'):
84
+ dtype, key = string[1:].split(':')
85
+ dtype = self.dtype_map[dtype]
86
+
87
+ value = environment.get(key)
88
+ if value is None:
89
+ raise AssertionError('environment not exist. key: {}'.format(key))
90
+
91
+ value = dtype(value)
92
+ result = value
93
+ else:
94
+ result = string
95
+ return result
96
+
97
+
98
+ def demo1():
99
+ import json
100
+
101
+ from project_settings import project_path
102
+
103
+ environment = EnvironmentManager(
104
+ path=os.path.join(project_path, 'server/callbot_server/dotenv'),
105
+ env='dev',
106
+ )
107
+ init_scenes = environment.get(key='init_scenes', dtype=json.loads)
108
+ print(init_scenes)
109
+ print(environment._environ)
110
+ return
111
+
112
+
113
+ if __name__ == '__main__':
114
+ demo1()
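
A sketch of the intended JsonConfig round trip; by default placeholders are resolved against os.environ, and the keys and values below are made up for illustration:

    import os
    from toolbox.os.environment import JsonConfig

    os.environ["threshold"] = "0.75"
    os.environ["port"] = "9200"

    config = JsonConfig()
    result = config.sanitize_by_json({
        "threshold": "$float:threshold",
        "port": "$int:port",
        "name": "callbot",
    })
    print(result)   # {'threshold': 0.75, 'port': 9200, 'name': 'callbot'}
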
toolbox/os/other.py ADDED
@@ -0,0 +1,9 @@
1
+ import os
2
+ import inspect
3
+
4
+
5
+ def pwd():
6
+ """你在哪个文件调用此函数, 它就会返回那个文件所在的 dir 目标"""
7
+ frame = inspect.stack()[1]
8
+ module = inspect.getmodule(frame[0])
9
+ return os.path.dirname(os.path.abspath(module.__file__))
toolbox/torch/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ if __name__ == '__main__':
5
+ pass
toolbox/torch/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torch/modules/gaussian_mixture.py ADDED
@@ -0,0 +1,173 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/georgepar/gmmhmm-pytorch/blob/master/gmm.py
5
+ https://github.com/ldeecke/gmm-torch
6
+ """
7
+ import math
8
+
9
+ from sklearn import cluster
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+
14
+ class GaussianMixtureModel(nn.Module):
15
+ def __init__(self,
16
+ n_mixtures: int,
17
+ n_features: int,
18
+ init: str = "random",
19
+ device: str = 'cpu',
20
+ n_iter: int = 1000,
21
+ delta: float = 1e-3,
22
+ warm_start: bool = False,
23
+ ):
24
+ super(GaussianMixtureModel, self).__init__()
25
+ self.n_mixtures = n_mixtures
26
+ self.n_features = n_features
27
+ self.init = init
28
+ self.device = device
29
+ self.n_iter = n_iter
30
+ self.delta = delta
31
+ self.warm_start = warm_start
32
+
33
+ if init not in ('kmeans', 'random'):
34
+ raise AssertionError
35
+
36
+ self.mu = nn.Parameter(
37
+ torch.Tensor(n_mixtures, n_features),
38
+ requires_grad=False,
39
+ )
40
+
41
+ self.sigma = None
42
+
43
+ # the weight of each gaussian
44
+ self.pi = nn.Parameter(
45
+ torch.Tensor(n_mixtures),
46
+ requires_grad=False
47
+ )
48
+
49
+ self.converged_ = False
50
+ self.eps = 1e-6
51
+ self.delta = delta
52
+ self.warm_start = warm_start
53
+ self.n_iter = n_iter
54
+
55
+ def reset_sigma(self):
56
+ raise NotImplementedError
57
+
58
+ def estimate_precisions(self):
59
+ raise NotImplementedError
60
+
61
+ def log_prob(self, x):
62
+ raise NotImplementedError
63
+
64
+ def weighted_log_prob(self, x):
65
+ log_prob = self.log_prob(x)
66
+ weighted_log_prob = log_prob + torch.log(self.pi)
67
+ return weighted_log_prob
68
+
69
+ def log_likelihood(self, x):
70
+ weighted_log_prob = self.weighted_log_prob(x)
71
+ per_sample_log_likelihood = torch.logsumexp(weighted_log_prob, dim=1)
72
+ log_likelihood = torch.sum(per_sample_log_likelihood)
73
+ return log_likelihood
74
+
75
+ def e_step(self, x):
76
+ weighted_log_prob = self.weighted_log_prob(x)
77
+ weighted_log_prob = weighted_log_prob.unsqueeze(dim=-1)
78
+ log_likelihood = torch.logsumexp(weighted_log_prob, dim=1, keepdim=True)
79
+ q = weighted_log_prob - log_likelihood
80
+ return q.squeeze()
81
+
82
+ def m_step(self, x, q):
83
+ x = x.unsqueeze(dim=1)
84
+
85
+ return
86
+
87
+ def estimate_mu(self, x, pi, responsibilities):
88
+ nk = pi * x.size(0)
89
+ mu = torch.sum(responsibilities * x, dim=0, keepdim=True) / nk
90
+ return mu
91
+
92
+ def estimate_pi(self, x, responsibilities):
93
+ pi = torch.sum(responsibilities, dim=0, keepdim=True) + self.eps
94
+ pi = pi / x.size(0)
95
+ return pi
96
+
97
+ def reset_parameters(self, x=None):
98
+ if self.init == 'random' or x is None:
99
+ self.mu.normal_()
100
+ self.reset_sigma()
101
+ self.pi.fill_(1.0 / self.n_mixtures)
102
+ elif self.init == 'kmeans':
103
+ centroids = cluster.KMeans(n_clusters=self.n_mixtures, n_init=1).fit(x).cluster_centers_
104
+ centroids = torch.tensor(centroids).to(self.device)
105
+ self.update_(mu=centroids)
106
+ else:
107
+ raise NotImplementedError
108
+
109
+
110
+ class DiagonalCovarianceGMM(GaussianMixtureModel):
111
+ def __init__(self,
112
+ n_mixtures: int,
113
+ n_features: int,
114
+ init: str = "random",
115
+ device: str = 'cpu',
116
+ n_iter: int = 1000,
117
+ delta: float = 1e-3,
118
+ warm_start: bool = False,
119
+ ):
120
+ super(DiagonalCovarianceGMM, self).__init__(
121
+ n_mixtures=n_mixtures,
122
+ n_features=n_features,
123
+ init=init,
124
+ device=device,
125
+ n_iter=n_iter,
126
+ delta=delta,
127
+ warm_start=warm_start,
128
+ )
129
+ self.sigma = nn.Parameter(
130
+ torch.Tensor(n_mixtures, n_features), requires_grad=False
131
+ )
132
+ self.reset_parameters()
133
+ self.to(self.device)
134
+
135
+ def reset_sigma(self):
136
+ self.sigma.fill_(1)
137
+
138
+ def estimate_precisions(self):
139
+ return torch.rsqrt(self.sigma)
140
+
141
+ def log_prob(self, x):
142
+ precisions = self.estimate_precisions()
143
+
144
+ x = x.unsqueeze(1)
145
+ mu = self.mu.unsqueeze(0)
146
+ precisions = precisions.unsqueeze(0)
147
+
148
+ # This is outer product
149
+ exp_term = torch.sum(
150
+ (mu * mu + x * x - 2 * x * mu) * (precisions ** 2), dim=2, keepdim=True
151
+ )
152
+ log_det = torch.sum(torch.log(precisions), dim=2, keepdim=True)
153
+
154
+ logp = -0.5 * (self.n_features * math.log(2 * math.pi) + exp_term) + log_det
155
+
156
+ return logp.squeeze()
157
+
158
+ def estimate_sigma(self, x, mu, pi, responsibilities):
159
+ nk = pi * x.size(0)
160
+ x2 = (responsibilities * x * x).sum(0, keepdim=True) / nk
161
+ mu2 = mu * mu
162
+ xmu = (responsibilities * mu * x).sum(0, keepdim=True) / nk
163
+ sigma = x2 - 2 * xmu + mu2 + self.eps
164
+
165
+ return sigma
166
+
167
+
168
+ def demo1():
169
+ return
170
+
171
+
172
+ if __name__ == '__main__':
173
+ demo1()
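
The EM update (m_step) above is still unfinished, but the scoring path works; a minimal sketch that evaluates per-component log-densities and E-step responsibilities under the default random initialisation:

    import torch
    from toolbox.torch.modules.gaussian_mixture import DiagonalCovarianceGMM

    gmm = DiagonalCovarianceGMM(n_mixtures=4, n_features=2)
    x = torch.randn(8, 2)

    log_prob = gmm.log_prob(x)               # shape [8, 4], one log-density per mixture component
    responsibilities = gmm.e_step(x).exp()   # shape [8, 4], each row sums to 1
    print(gmm.log_likelihood(x))             # scalar total log-likelihood
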
toolbox/torch/modules/highway.py ADDED
@@ -0,0 +1,30 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import torch.nn as nn
4
+
5
+
6
+ class Highway(nn.Module):
7
+ """
8
+ https://arxiv.org/abs/1505.00387
9
+ [Submitted on 3 May 2015 (v1), last revised 3 Nov 2015 (this version, v2)]
10
+
11
+ discuss of Highway and ResNet
12
+ https://www.zhihu.com/question/279426970
13
+ """
14
+ def __init__(self, in_size, out_size):
15
+ super(Highway, self).__init__()
16
+ self.H = nn.Linear(in_size, out_size)
17
+ self.H.bias.data.zero_()
18
+ self.T = nn.Linear(in_size, out_size)
19
+ self.T.bias.data.fill_(-1)
20
+ self.relu = nn.ReLU()
21
+ self.sigmoid = nn.Sigmoid()
22
+
23
+ def forward(self, inputs):
24
+ H = self.relu(self.H(inputs))
25
+ T = self.sigmoid(self.T(inputs))
26
+ return H * T + inputs * (1.0 - T)
27
+
28
+
29
+ if __name__ == '__main__':
30
+ pass
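
Because the carry term adds the inputs back, the layer above only makes sense when in_size == out_size; with the transform-gate bias initialised at -1 the gate starts mostly closed, so a fresh layer is close to the identity map. A minimal sketch:

    import torch
    from toolbox.torch.modules.highway import Highway

    layer = Highway(in_size=16, out_size=16)
    x = torch.randn(4, 16)
    y = layer(x)                  # H(x) * T(x) + x * (1 - T(x)), shape [4, 16]
    print((y - x).abs().mean())   # small right after initialisation: the carry path dominates
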
toolbox/torch/modules/loss.py ADDED
@@ -0,0 +1,738 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import math
4
+ from typing import List, Optional
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.nn.modules.loss import _Loss
11
+ from torch.autograd import Variable
12
+
13
+
14
+ class ClassBalancedLoss(_Loss):
15
+ """
16
+ https://arxiv.org/abs/1901.05555
17
+ """
18
+ @staticmethod
19
+ def demo1():
20
+ batch_loss: torch.FloatTensor = torch.randn(size=(2, 1), dtype=torch.float32)
21
+ targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)
22
+
23
+ class_balanced_loss = ClassBalancedLoss(
24
+ num_classes=3,
25
+ num_samples_each_class=[300, 433, 50],
26
+ reduction='mean',
27
+ )
28
+ loss = class_balanced_loss.forward(batch_loss=batch_loss, targets=targets)
29
+ print(loss)
30
+ return
31
+
32
+ @staticmethod
33
+ def demo2():
34
+ inputs: torch.FloatTensor = torch.randn(size=(2, 3), dtype=torch.float32)
35
+ targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)
36
+
37
+ focal_loss = FocalLoss(
38
+ num_classes=3,
39
+ # reduction='mean',
40
+ # reduction='sum',
41
+ reduction='none',
42
+ )
43
+ batch_loss = focal_loss.forward(inputs, targets)
44
+ print(batch_loss)
45
+
46
+ class_balanced_loss = ClassBalancedLoss(
47
+ num_classes=3,
48
+ num_samples_each_class=[300, 433, 50],
49
+ reduction='mean',
50
+ )
51
+ loss = class_balanced_loss.forward(batch_loss=batch_loss, targets=targets)
52
+ print(loss)
53
+
54
+ return
55
+
56
+ def __init__(self,
57
+ num_classes: int,
58
+ num_samples_each_class: List[int],
59
+ beta: float = 0.999,
60
+ reduction: str = 'mean') -> None:
61
+ super(ClassBalancedLoss, self).__init__(None, None, reduction)
62
+
63
+ effective_num = 1.0 - np.power(beta, num_samples_each_class)
64
+ weights = (1.0 - beta) / np.array(effective_num)
65
+ self.weights = weights / np.sum(weights) * num_classes
66
+
67
+ def forward(self, batch_loss: torch.FloatTensor, targets: torch.LongTensor):
68
+ """
69
+ :param batch_loss: shape=[batch_size, 1]
70
+ :param targets: shape=[batch_size,]
71
+ :return:
72
+ """
73
+ weights = list()
74
+ targets = targets.numpy()
75
+ for target in targets:
76
+ weights.append([self.weights[target]])
77
+
78
+ weights = torch.tensor(weights, dtype=torch.float32)
79
+ batch_loss = weights * batch_loss
80
+
81
+ if self.reduction == 'mean':
82
+ loss = batch_loss.mean()
83
+ elif self.reduction == 'sum':
84
+ loss = batch_loss.sum()
85
+ else:
86
+ loss = batch_loss
87
+ return loss
88
+
89
+
90
+ class EqualizationLoss(_Loss):
91
+ """
92
+ For sigmoid multi-label classification in image recognition, with a background class in addition to the num_classes classes.
93
+ Equalization Loss
94
+ https://arxiv.org/abs/2003.05176
95
+ Equalization Loss v2
96
+ https://arxiv.org/abs/2012.08548
97
+ """
98
+
99
+ @staticmethod
100
+ def demo1():
101
+ logits: torch.FloatTensor = torch.randn(size=(3, 3), dtype=torch.float32)
102
+ targets: torch.LongTensor = torch.tensor([1, 2, 3], dtype=torch.long)
103
+
104
+ equalization_loss = EqualizationLoss(
105
+ num_samples_each_class=[300, 433, 50],
106
+ threshold=100,
107
+ reduction='mean',
108
+ )
109
+ loss = equalization_loss.forward(logits=logits, targets=targets)
110
+ print(loss)
111
+ return
112
+
113
+ def __init__(self,
114
+ num_samples_each_class: List[int],
115
+ threshold: int = 100,
116
+ reduction: str = 'mean') -> None:
117
+ super(EqualizationLoss, self).__init__(None, None, reduction)
118
+ self.num_samples_each_class = np.array(num_samples_each_class, dtype=np.int32)
119
+ self.threshold = threshold
120
+
121
+ def forward(self,
122
+ logits: torch.FloatTensor,
123
+ targets: torch.LongTensor
124
+ ):
125
+ """
126
+ num_classes + 1 对应于背景类别 background.
127
+ :param logits: shape=[batch_size, num_classes]
128
+ :param targets: shape=[batch_size]
129
+ :return:
130
+ """
131
+ batch_size, num_classes = logits.size()
132
+
133
+ one_hot_targets = F.one_hot(targets, num_classes=num_classes + 1)
134
+ one_hot_targets = one_hot_targets[:, :-1]
135
+
136
+ exclude = self.exclude_func(
137
+ num_classes=num_classes,
138
+ targets=targets
139
+ )
140
+ is_tail = self.threshold_func(
141
+ num_classes=num_classes,
142
+ num_samples_each_class=self.num_samples_each_class,
143
+ threshold=self.threshold,
144
+ )
145
+
146
+ weights = 1 - exclude * is_tail * (1 - one_hot_targets)
147
+
148
+ batch_loss = F.binary_cross_entropy_with_logits(
149
+ logits,
150
+ one_hot_targets.float(),
151
+ reduction='none'
152
+ )
153
+
154
+ batch_loss = weights * batch_loss
155
+
156
+ if self.reduction == 'mean':
157
+ loss = batch_loss.mean()
158
+ elif self.reduction == 'sum':
159
+ loss = batch_loss.sum()
160
+ else:
161
+ loss = batch_loss
162
+
163
+ loss = loss / num_classes
164
+ return loss
165
+
166
+ @staticmethod
167
+ def exclude_func(num_classes: int, targets: torch.LongTensor):
168
+ """
169
+ The last class index is the background class.
170
+ :param num_classes: int,
171
+ :param targets: shape=[batch_size,]
172
+ :return: weight, shape=[batch_size, num_classes]
173
+ """
174
+ batch_size = targets.shape[0]
175
+ weight = (targets != num_classes).float()
176
+ weight = weight.view(batch_size, 1).expand(batch_size, num_classes)
177
+ return weight
178
+
179
+ @staticmethod
180
+ def threshold_func(num_classes: int, num_samples_each_class: np.ndarray, threshold: int):
181
+ """
182
+ :param num_classes: int,
183
+ :param num_samples_each_class: shape=[num_classes]
184
+ :param threshold: int,
185
+ :return: weight, shape=[1, num_classes]
186
+ """
187
+ weight = torch.zeros(size=(num_classes,))
188
+ weight[num_samples_each_class < threshold] = 1
189
+ weight = torch.unsqueeze(weight, dim=0)
190
+ return weight
191
+
192
+
193
+ class FocalLoss(_Loss):
194
+ """
195
+ https://arxiv.org/abs/1708.02002
196
+ """
197
+ @staticmethod
198
+ def demo1():
199
+ inputs: torch.FloatTensor = torch.randn(size=(2, 3), dtype=torch.float32)
200
+ targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)
201
+
202
+ focal_loss = FocalLoss(
203
+ num_classes=3,
204
+ reduction='mean',
205
+ # reduction='sum',
206
+ # reduction='none',
207
+ )
208
+ loss = focal_loss.forward(inputs, targets)
209
+ print(loss)
210
+ return
211
+
212
+ def __init__(self,
213
+ num_classes: int,
214
+ alpha: List[float] = None,
215
+ gamma: int = 2,
216
+ reduction: str = 'mean',
217
+ inputs_logits: bool = True) -> None:
218
+ """
219
+ :param num_classes:
220
+ :param alpha:
221
+ :param gamma:
222
+ :param reduction: (`none`, `mean`, `sum`) available.
223
+ :param inputs_logits: if False, the inputs should be probs.
224
+ """
225
+ super(FocalLoss, self).__init__(None, None, reduction)
226
+ if alpha is None:
227
+ self.alpha = torch.ones(num_classes, 1)
228
+ else:
229
+ self.alpha = torch.tensor(alpha, dtype=torch.float32)
230
+ self.gamma = gamma
231
+ self.num_classes = num_classes
232
+ self.inputs_logits = inputs_logits
233
+
234
+ def forward(self,
235
+ inputs: torch.FloatTensor,
236
+ targets: torch.LongTensor):
237
+ """
238
+ :param inputs: logits, shape=[batch_size, num_classes]
239
+ :param targets: shape=[batch_size,]
240
+ :return:
241
+ """
242
+ batch_size, num_classes = inputs.shape
243
+
244
+ if self.inputs_logits:
245
+ probs = F.softmax(inputs, dim=-1)
246
+ else:
247
+ probs = inputs
248
+
249
+ # class_mask = inputs.data.new(batch_size, num_classes).fill_(0)
250
+ class_mask = torch.zeros(size=(batch_size, num_classes), dtype=inputs.dtype, device=inputs.device)
251
+ # class_mask = Variable(class_mask)
252
+ ids = targets.view(-1, 1)
253
+ class_mask.scatter_(1, ids.data, 1.)
254
+
255
+ if inputs.is_cuda and not self.alpha.is_cuda:
256
+ self.alpha = self.alpha.cuda()
257
+ alpha = self.alpha[ids.data.view(-1)]
258
+
259
+ probs = (probs * class_mask).sum(1).view(-1, 1)
260
+
261
+ log_p = probs.log()
262
+
263
+ batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p
264
+
265
+ if self.reduction == 'mean':
266
+ loss = batch_loss.mean()
267
+ elif self.reduction == 'sum':
268
+ loss = batch_loss.sum()
269
+ else:
270
+ loss = batch_loss
271
+ return loss
272
+
273
+
274
+ class HingeLoss(_Loss):
275
+ @staticmethod
276
+ def demo1():
277
+ inputs: torch.FloatTensor = torch.randn(size=(2, 3), dtype=torch.float32)
278
+ targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)
279
+
280
+ hinge_loss = HingeLoss(
281
+ margin_list=[300, 433, 50],
282
+ reduction='mean',
283
+ )
284
+ loss = hinge_loss.forward(inputs=inputs, targets=targets)
285
+ print(loss)
286
+ return
287
+
288
+ def __init__(self,
289
+ margin_list: List[float],
290
+ max_margin: float = 0.5,
291
+ scale: float = 1.0,
292
+ weight: Optional[torch.Tensor] = None,
293
+ reduction: str = 'mean') -> None:
294
+ super(HingeLoss, self).__init__(None, None, reduction)
295
+
296
+ self.max_margin = max_margin
297
+ self.scale = scale
298
+ self.weight = weight
299
+
300
+ margin_list = np.array(margin_list)
301
+ margin_list = margin_list * (max_margin / np.max(margin_list))
302
+ self.margin_list = torch.tensor(margin_list, dtype=torch.float32)
303
+
304
+ def forward(self,
305
+ inputs: torch.FloatTensor,
306
+ targets: torch.LongTensor
307
+ ):
308
+ """
309
+ :param inputs: logits, shape=[batch_size, num_classes]
310
+ :param targets: shape=[batch_size,]
311
+ :return:
312
+ """
313
+ batch_size, num_classes = inputs.shape
314
+ one_hot_targets = F.one_hot(targets, num_classes=num_classes)
315
+ margin_list = torch.unsqueeze(self.margin_list, dim=0)
316
+
317
+ batch_margin = torch.sum(margin_list * one_hot_targets, dim=-1)
318
+ batch_margin = torch.unsqueeze(batch_margin, dim=-1)
319
+ inputs_margin = inputs - batch_margin
320
+
321
+ # Reduce the logit of the target class a little to create a margin.
322
+ logits = torch.where(one_hot_targets > 0, inputs_margin, inputs)
323
+
324
+ loss = F.cross_entropy(
325
+ input=self.scale * logits,
326
+ target=targets,
327
+ weight=self.weight,
328
+ reduction=self.reduction,
329
+ )
330
+ return loss
331
+
332
+
333
+ class HingeLinear(nn.Module):
334
+ """
335
+ use this instead of `HingeLoss`, then you can combine it with `FocalLoss` or others.
336
+ """
337
+ def __init__(self,
338
+ margin_list: List[float],
339
+ max_margin: float = 0.5,
340
+ scale: float = 1.0,
341
+ weight: Optional[torch.Tensor] = None
342
+ ) -> None:
343
+ super(HingeLinear, self).__init__()
344
+
345
+ self.max_margin = max_margin
346
+ self.scale = scale
347
+ self.weight = weight
348
+
349
+ margin_list = np.array(margin_list)
350
+ margin_list = margin_list * (max_margin / np.max(margin_list))
351
+ self.margin_list = torch.tensor(margin_list, dtype=torch.float32)
352
+
353
+ def forward(self,
354
+ inputs: torch.FloatTensor,
355
+ targets: torch.LongTensor
356
+ ):
357
+ """
358
+ :param inputs: logits, shape=[batch_size, num_classes]
359
+ :param targets: shape=[batch_size,]
360
+ :return:
361
+ """
362
+ if self.training and targets is not None:
363
+ batch_size, num_classes = inputs.shape
364
+ one_hot_targets = F.one_hot(targets, num_classes=num_classes)
365
+ margin_list = torch.unsqueeze(self.margin_list, dim=0)
366
+
367
+ batch_margin = torch.sum(margin_list * one_hot_targets, dim=-1)
368
+ batch_margin = torch.unsqueeze(batch_margin, dim=-1)
369
+ inputs_margin = inputs - batch_margin
370
+
371
+ # Reduce the logit of the target class a little to create a margin.
372
+ logits = torch.where(one_hot_targets > 0, inputs_margin, inputs)
373
+ logits = logits * self.scale
374
+ else:
375
+ logits = inputs
376
+ return logits
377
+
378
+
379
+ class LDAMLoss(_Loss):
380
+ """
381
+ https://arxiv.org/abs/1906.07413
382
+ """
383
+ @staticmethod
384
+ def demo1():
385
+ inputs: torch.FloatTensor = torch.randn(size=(2, 3), dtype=torch.float32)
386
+ targets: torch.LongTensor = torch.tensor([1, 2], dtype=torch.long)
387
+
388
+ ldam_loss = LDAMLoss(
389
+ num_samples_each_class=[300, 433, 50],
390
+ reduction='mean',
391
+ )
392
+ loss = ldam_loss.forward(inputs=inputs, targets=targets)
393
+ print(loss)
394
+ return
395
+
396
+ def __init__(self,
397
+ num_samples_each_class: List[int],
398
+ max_margin: float = 0.5,
399
+ scale: float = 30.0,
400
+ weight: Optional[torch.Tensor] = None,
401
+ reduction: str = 'mean') -> None:
402
+ super(LDAMLoss, self).__init__(None, None, reduction)
403
+
404
+ margin_list = np.power(num_samples_each_class, -0.25)
405
+ margin_list = margin_list * (max_margin / np.max(margin_list))
406
+
407
+ self.num_samples_each_class = num_samples_each_class
408
+ self.margin_list = torch.tensor(margin_list, dtype=torch.float32)
409
+ self.scale = scale
410
+ self.weight = weight
411
+
412
+ def forward(self,
413
+ inputs: torch.FloatTensor,
414
+ targets: torch.LongTensor
415
+ ):
416
+ """
417
+ :param inputs: logits, shape=[batch_size, num_classes]
418
+ :param targets: shape=[batch_size,]
419
+ :return:
420
+ """
421
+ batch_size, num_classes = inputs.shape
422
+ one_hot_targets = F.one_hot(targets, num_classes=num_classes)
423
+ margin_list = torch.unsqueeze(self.margin_list, dim=0)
424
+
425
+ batch_margin = torch.sum(margin_list * one_hot_targets, dim=-1)
426
+ batch_margin = torch.unsqueeze(batch_margin, dim=-1)
427
+ inputs_margin = inputs - batch_margin
428
+
429
+ # Reduce the logit of the target class a little to create a margin.
430
+ logits = torch.where(one_hot_targets > 0, inputs_margin, inputs)
431
+
432
+ loss = F.cross_entropy(
433
+ input=self.scale * logits,
434
+ target=targets,
435
+ weight=self.weight,
436
+ reduction=self.reduction,
437
+ )
438
+ return loss
439
+
440
+
441
+ class NegativeEntropy(_Loss):
442
+ def __init__(self,
443
+ reduction: str = 'mean',
444
+ inputs_logits: bool = True) -> None:
445
+ super(NegativeEntropy, self).__init__(None, None, reduction)
446
+ self.inputs_logits = inputs_logits
447
+
448
+ def forward(self,
449
+ inputs: torch.FloatTensor,
450
+ targets: torch.LongTensor):
451
+ if self.inputs_logits:
452
+ probs = F.softmax(inputs, dim=-1)
453
+ log_probs = torch.nn.functional.log_softmax(inputs, dim=-1)
454
+ else:
455
+ probs = inputs
456
+ log_probs = torch.log(probs)
457
+
458
+ weighted_negative_likelihood = - log_probs * probs
459
+
460
+ loss = - weighted_negative_likelihood.sum()
461
+ return loss
462
+
463
+
464
+ class LargeMarginSoftMaxLoss(_Loss):
465
+ """
466
+ Alias: L-Softmax
467
+
468
+ https://arxiv.org/abs/1612.02295
469
+ https://github.com/wy1iu/LargeMargin_Softmax_Loss
470
+ https://github.com/amirhfarzaneh/lsoftmax-pytorch/blob/master/lsoftmax.py
471
+
472
+ Reference:
473
+ https://www.jianshu.com/p/06cc3f84aa85
474
+
475
+ The paper argues that the softmax + cross entropy combination does not explicitly encourage discriminative feature learning.
476
+
477
+ """
478
+ def __init__(self,
479
+ reduction: str = 'mean') -> None:
480
+ super(LargeMarginSoftMaxLoss, self).__init__(None, None, reduction)
481
+
482
+
483
+ class AngularSoftMaxLoss(_Loss):
484
+ """
485
+ Alias: A-Softmax
486
+
487
+ https://arxiv.org/abs/1704.08063
488
+
489
+ https://github.com/woshildh/a-softmax_pytorch/blob/master/a_softmax.py
490
+
491
+ Reference:
492
+ https://www.jianshu.com/p/06cc3f84aa85
493
+
494
+ The authors seem to treat faces as lying on a hypersphere, so mapping the vectors onto a sphere is helpful.
495
+ """
496
+ def __init__(self,
497
+ reduction: str = 'mean') -> None:
498
+ super(AngularSoftMaxLoss, self).__init__(None, None, reduction)
499
+
500
+
501
+ class AdditiveMarginSoftMax(_Loss):
502
+ """
503
+ Alias: AM-Softmax
504
+
505
+ https://arxiv.org/abs/1801.05599
506
+
507
+ Large Margin Cosine Loss
508
+ https://arxiv.org/abs/1801.09414
509
+
510
+ Reference:
511
+ https://www.jianshu.com/p/06cc3f84aa85
512
+
513
+ Notes:
514
+ Compared with a plain softmax over the logits,
515
+ it subtracts m from the logit of the true class, pushing the model to make that logit larger.
516
+ It also multiplies every logit by s, which controls the relative scale of the logits.
517
+ This is somewhat similar to HingeLoss.
518
+ """
519
+ def __init__(self,
520
+ reduction: str = 'mean') -> None:
521
+ super(AdditiveMarginSoftMax, self).__init__(None, None, reduction)
522
+
523
+
524
+ class AdditiveAngularMarginSoftMax(_Loss):
525
+ """
526
+ Alias: ArcFace, AAM-Softmax
527
+
528
+ ArcFace: Additive Angular Margin Loss for Deep Face Recognition
529
+ https://arxiv.org/abs/1801.07698
530
+
531
+ Reference code:
532
+ https://github.com/huangkeju/AAMSoftmax-OpenMax/blob/main/AAMSoftmax%2BOvA/metrics.py
533
+
534
+ """
535
+ @staticmethod
536
+ def demo1():
537
+ """
538
+ Conversion between degrees and radians:
539
+ pi / 180 corresponds to 1 degree,
540
+ pi / 180 = 0.01745
541
+ """
542
+
543
+ # degrees to radians
544
+ degree = 10
545
+ result = degree * math.pi / 180
546
+ print(result)
547
+
548
+ # radians to degrees
549
+ radian = 0.2
550
+ result = radian / (math.pi / 180)
551
+ print(result)
552
+
553
+ return
554
+
555
+ def __init__(self,
556
+ hidden_size: int,
557
+ num_labels: int,
558
+ margin: float = 0.2,
559
+ scale: float = 10.0,
560
+ ):
561
+ """
562
+ :param hidden_size:
563
+ :param num_labels:
564
+ :param margin: suggested angle range is [10, 30] degrees, i.e. [0.1745, 0.5236] in radians
565
+ :param scale:
566
+ """
567
+ super(AdditiveAngularMarginSoftMax, self).__init__()
568
+ self.margin = margin
569
+ self.scale = scale
570
+ self.weight = torch.nn.Parameter(torch.FloatTensor(num_labels, hidden_size), requires_grad=True)
571
+ nn.init.xavier_uniform_(self.weight)
572
+
573
+ self.cos_margin = math.cos(self.margin)
574
+ self.sin_margin = math.sin(self.margin)
575
+
576
+ # sin(a-b) = sin(a)cos(b) - cos(a)sin(b)
577
+ # sin(pi - a) = sin(a)
578
+
579
+ self.loss = nn.CrossEntropyLoss()
580
+
581
+ def forward(self,
582
+ inputs: torch.Tensor,
583
+ label: torch.LongTensor = None
584
+ ):
585
+ """
586
+ :param inputs: shape=[batch_size, ..., hidden_size]
587
+ :param label:
588
+ :return: logits
589
+ """
590
+ x = F.normalize(inputs)
591
+ weight = F.normalize(self.weight)
592
+ cosine = F.linear(x, weight)
593
+
594
+ if self.training:
595
+
596
+ # sin^2 + cos^2 = 1
597
+ sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))
598
+
599
+ # cos(a+b) = cos(a)cos(b) - sin(a)sin(b)
600
+ cosine_theta_margin = cosine * self.cos_margin - sine * self.sin_margin
601
+
602
+ # when the `cosine > - self.cos_margin` there is enough space to add margin on theta.
603
+ cosine_theta_margin = torch.where(cosine > - self.cos_margin, cosine_theta_margin, cosine - (self.margin * self.sin_margin))
604
+
605
+ one_hot = torch.zeros_like(cosine)
606
+ one_hot.scatter_(1, label.view(-1, 1), 1)
607
+
608
+ #
609
+ logits = torch.where(one_hot == 1, cosine_theta_margin, cosine)
610
+ logits = logits * self.scale
611
+ else:
612
+ logits = cosine
613
+
614
+ loss = self.loss(logits, label)
615
+ # prec1 = accuracy(output.detach(), label.detach(), topk=(1,))[0]
616
+ return loss
617
+
618
+
619
+ class AdditiveAngularMarginLinear(nn.Module):
620
+ """
621
+ Alias: ArcFace, AAM-Softmax
622
+
623
+ ArcFace: Additive Angular Margin Loss for Deep Face Recognition
624
+ https://arxiv.org/abs/1801.07698
625
+
626
+ Reference code:
627
+ https://github.com/huangkeju/AAMSoftmax-OpenMax/blob/main/AAMSoftmax%2BOvA/metrics.py
628
+
629
+ """
630
+ @staticmethod
631
+ def demo1():
632
+ """
633
+ Conversion between degrees and radians:
634
+ pi / 180 corresponds to 1 degree,
635
+ pi / 180 = 0.01745
636
+ """
637
+
638
+ # degrees to radians
639
+ degree = 10
640
+ result = degree * math.pi / 180
641
+ print(result)
642
+
643
+ # radians to degrees
644
+ radian = 0.2
645
+ result = radian / (math.pi / 180)
646
+ print(result)
647
+
648
+ return
649
+
650
+ @staticmethod
651
+ def demo2():
652
+
653
+ return
654
+
655
+ def __init__(self,
656
+ hidden_size: int,
657
+ num_labels: int,
658
+ margin: float = 0.2,
659
+ scale: float = 10.0,
660
+ ):
661
+ """
662
+ :param hidden_size:
663
+ :param num_labels:
664
+ :param margin: suggested angle range is [10, 30] degrees, i.e. [0.1745, 0.5236] in radians
665
+ :param scale:
666
+ """
667
+ super(AdditiveAngularMarginLinear, self).__init__()
668
+ self.margin = margin
669
+ self.scale = scale
670
+ self.weight = torch.nn.Parameter(torch.FloatTensor(num_labels, hidden_size), requires_grad=True)
671
+ nn.init.xavier_uniform_(self.weight)
672
+
673
+ self.cos_margin = math.cos(self.margin)
674
+ self.sin_margin = math.sin(self.margin)
675
+
676
+ # sin(a-b) = sin(a)cos(b) - cos(a)sin(b)
677
+ # sin(pi - a) = sin(a)
678
+
679
+ def forward(self,
680
+ inputs: torch.Tensor,
681
+ targets: torch.LongTensor = None
682
+ ):
683
+ """
684
+ :param inputs: shape=[batch_size, ..., hidden_size]
685
+ :param targets:
686
+ :return: logits
687
+ """
688
+ x = F.normalize(inputs)
689
+ weight = F.normalize(self.weight)
690
+ cosine = F.linear(x, weight)
691
+
692
+ if self.training and targets is not None:
693
+ # sin^2 + cos^2 = 1
694
+ sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))
695
+
696
+ # cos(a+b) = cos(a)cos(b) - sin(a)sin(b)
697
+ cosine_theta_margin = cosine * self.cos_margin - sine * self.sin_margin
698
+
699
+ # when the `cosine > - self.cos_margin` there is enough space to add margin on theta.
700
+ cosine_theta_margin = torch.where(cosine > - self.cos_margin, cosine_theta_margin, cosine - (self.margin * self.sin_margin))
701
+
702
+ one_hot = torch.zeros_like(cosine)
703
+ one_hot.scatter_(1, targets.view(-1, 1), 1)
704
+
705
+ logits = torch.where(one_hot == 1, cosine_theta_margin, cosine)
706
+ logits = logits * self.scale
707
+ else:
708
+ logits = cosine
709
+ return logits
710
+
711
+
712
+ def demo1():
713
+ HingeLoss.demo1()
714
+ return
715
+
716
+
717
+ def demo2():
718
+ AdditiveAngularMarginSoftMax.demo1()
719
+
720
+ inputs = torch.ones(size=(2, 5), dtype=torch.float32)
721
+ label: torch.LongTensor = torch.tensor(data=[0, 1], dtype=torch.long)
722
+
723
+ aam_softmax = AdditiveAngularMarginSoftMax(
724
+ hidden_size=5,
725
+ num_labels=2,
726
+ margin=1,
727
+ scale=1
728
+ )
729
+
730
+ outputs = aam_softmax.forward(inputs, label)
731
+ print(outputs)
732
+
733
+ return
734
+
735
+
736
+ if __name__ == '__main__':
737
+ # demo1()
738
+ demo2()
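
As the HingeLinear docstring suggests, the margin layers can be chained with a separate loss; a minimal sketch combining HingeLinear with FocalLoss (class counts and margins below are illustrative):

    import torch
    from toolbox.torch.modules.loss import FocalLoss, HingeLinear

    margin = HingeLinear(margin_list=[300, 433, 50], max_margin=0.5, scale=30.0)
    focal = FocalLoss(num_classes=3, reduction="mean")

    logits = torch.randn(4, 3)
    targets = torch.tensor([0, 2, 1, 1], dtype=torch.long)

    margin.train()                              # the margin is only applied in training mode
    adjusted_logits = margin(logits, targets)   # target-class logits shifted down by their margin, then scaled
    loss = focal(adjusted_logits, targets)
    print(loss)
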
toolbox/torch/training/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torch/training/metrics/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torch/training/metrics/categorical_accuracy.py ADDED
@@ -0,0 +1,82 @@
1
+ from typing import Optional
2
+
3
+ from overrides import overrides
4
+ import torch
5
+
6
+
7
+ class CategoricalAccuracy(object):
8
+ def __init__(self, top_k: int = 1, tie_break: bool = False) -> None:
9
+ if top_k > 1 and tie_break:
10
+ raise AssertionError("Tie break in Categorical Accuracy "
11
+ "can be done only for maximum (top_k = 1)")
12
+ if top_k <= 0:
13
+ raise AssertionError("top_k passed to Categorical Accuracy must be > 0")
14
+ self._top_k = top_k
15
+ self._tie_break = tie_break
16
+ self.correct_count = 0.
17
+ self.total_count = 0.
18
+
19
+ def __call__(self,
20
+ predictions: torch.Tensor,
21
+ gold_labels: torch.Tensor,
22
+ mask: Optional[torch.Tensor] = None):
23
+
24
+ # predictions, gold_labels, mask = self.unwrap_to_tensors(predictions, gold_labels, mask)
25
+
26
+ # Some sanity checks.
27
+ num_classes = predictions.size(-1)
28
+ if gold_labels.dim() != predictions.dim() - 1:
29
+ raise AssertionError("gold_labels must have dimension == predictions.size() - 1 but "
30
+ "found tensor of shape: {}".format(predictions.size()))
31
+ if (gold_labels >= num_classes).any():
32
+ raise AssertionError("A gold label passed to Categorical Accuracy contains an id >= {}, "
33
+ "the number of classes.".format(num_classes))
34
+
35
+ predictions = predictions.view((-1, num_classes))
36
+ gold_labels = gold_labels.view(-1).long()
37
+ if not self._tie_break:
38
+ # Top K indexes of the predictions (or fewer, if there aren't K of them).
39
+ # Special case topk == 1, because it's common and .max() is much faster than .topk().
40
+ if self._top_k == 1:
41
+ top_k = predictions.max(-1)[1].unsqueeze(-1)
42
+ else:
43
+ top_k = predictions.topk(min(self._top_k, predictions.shape[-1]), -1)[1]
44
+
45
+ # This is of shape (batch_size, ..., top_k).
46
+ correct = top_k.eq(gold_labels.unsqueeze(-1)).float()
47
+ else:
48
+ # prediction is correct if gold label falls on any of the max scores. distribute score by tie_counts
49
+ max_predictions = predictions.max(-1)[0]
50
+ max_predictions_mask = predictions.eq(max_predictions.unsqueeze(-1))
51
+ # max_predictions_mask is (rows X num_classes) and gold_labels is (batch_size)
52
+ # ith entry in gold_labels points to index (0-num_classes) for ith row in max_predictions
53
+ # For each row check if index pointed by gold_label is was 1 or not (among max scored classes)
54
+ correct = max_predictions_mask[torch.arange(gold_labels.numel()).long(), gold_labels].float()
55
+ tie_counts = max_predictions_mask.sum(-1)
56
+ correct /= tie_counts.float()
57
+ correct.unsqueeze_(-1)
58
+
59
+ if mask is not None:
60
+ correct *= mask.view(-1, 1).float()
61
+ self.total_count += mask.sum()
62
+ else:
63
+ self.total_count += gold_labels.numel()
64
+ self.correct_count += correct.sum()
65
+
66
+ def get_metric(self, reset: bool = False):
67
+ """
68
+ Returns
69
+ -------
70
+ The accumulated accuracy.
71
+ """
72
+ if self.total_count > 1e-12:
73
+ accuracy = float(self.correct_count) / float(self.total_count)
74
+ else:
75
+ accuracy = 0.0
76
+ if reset:
77
+ self.reset()
78
+ return {'accuracy': accuracy}
79
+
80
+ def reset(self):
81
+ self.correct_count = 0.0
82
+ self.total_count = 0.0
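
A minimal sketch of the metric above; counts accumulate across calls until reset:

    import torch
    from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy

    metric = CategoricalAccuracy(top_k=1)
    predictions = torch.tensor([[0.1, 0.9], [0.8, 0.2]])   # two samples, two classes
    gold_labels = torch.tensor([1, 1])
    metric(predictions, gold_labels)
    print(metric.get_metric(reset=True))                   # {'accuracy': 0.5}
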
toolbox/torch/training/metrics/verbose_categorical_accuracy.py ADDED
@@ -0,0 +1,128 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import Dict, List, Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+
9
+ class CategoricalAccuracyVerbose(object):
10
+ def __init__(self,
11
+ index_to_token: Dict[int, str],
12
+ label_namespace: str = "labels",
13
+ top_k: int = 1,
14
+ ) -> None:
15
+ if top_k <= 0:
16
+ raise AssertionError("top_k passed to Categorical Accuracy must be > 0")
17
+ self._index_to_token = index_to_token
18
+ self._label_namespace = label_namespace
19
+ self._top_k = top_k
20
+ self.correct_count = 0.
21
+ self.total_count = 0.
22
+ self.label_correct_count = dict()
23
+ self.label_total_count = dict()
24
+
25
+ def __call__(self,
26
+ predictions: torch.Tensor,
27
+ gold_labels: torch.Tensor,
28
+ mask: Optional[torch.Tensor] = None):
29
+ num_classes = predictions.size(-1)
30
+ if gold_labels.dim() != predictions.dim() - 1:
31
+ raise AssertionError("gold_labels must have dimension == predictions.size() - 1 but "
32
+ "found tensor of shape: {}".format(predictions.size()))
33
+ if (gold_labels >= num_classes).any():
34
+ raise AssertionError("A gold label passed to Categorical Accuracy contains an id >= {}, "
35
+ "the number of classes.".format(num_classes))
36
+
37
+ predictions = predictions.view((-1, num_classes))
38
+ gold_labels = gold_labels.view(-1).long()
39
+
40
+ # Top K indexes of the predictions (or fewer, if there aren't K of them).
41
+ # Special case topk == 1, because it's common and .max() is much faster than .topk().
42
+ if self._top_k == 1:
43
+ top_k = predictions.max(-1)[1].unsqueeze(-1)
44
+ else:
45
+ top_k = predictions.topk(min(self._top_k, predictions.shape[-1]), -1)[1]
46
+
47
+ # This is of shape (batch_size, ..., top_k).
48
+ correct = top_k.eq(gold_labels.unsqueeze(-1)).float()
49
+
50
+ if mask is not None:
51
+ correct *= mask.view(-1, 1).float()
52
+ self.total_count += mask.sum()
53
+ else:
54
+ self.total_count += gold_labels.numel()
55
+ self.correct_count += correct.sum()
56
+
57
+ labels: List[int] = np.unique(gold_labels.cpu().numpy()).tolist()
58
+ for label in labels:
59
+ label_mask = (gold_labels == label)
60
+
61
+ label_correct = correct * label_mask.view(-1, 1).float()
62
+ label_correct = int(label_correct.sum())
63
+ label_count = int(label_mask.sum())
64
+
65
+ label_str = self._index_to_token[label]
66
+ if label_str in self.label_correct_count:
67
+ self.label_correct_count[label_str] += label_correct
68
+ else:
69
+ self.label_correct_count[label_str] = label_correct
70
+
71
+ if label_str in self.label_total_count:
72
+ self.label_total_count[label_str] += label_count
73
+ else:
74
+ self.label_total_count[label_str] = label_count
75
+
76
+ def get_metric(self, reset: bool = False):
77
+ """
78
+ Returns
79
+ -------
80
+ The accumulated accuracy.
81
+ """
82
+ result = dict()
83
+ if self.total_count > 1e-12:
84
+ accuracy = float(self.correct_count) / float(self.total_count)
85
+ else:
86
+ accuracy = 0.0
87
+ result['accuracy'] = accuracy
88
+
89
+ for label in self.label_total_count.keys():
90
+ total = self.label_total_count[label]
91
+ correct = self.label_correct_count.get(label, 0.0)
92
+ label_accuracy = correct / total
93
+ result[label] = label_accuracy
94
+
95
+ if reset:
96
+ self.reset()
97
+ return result
98
+
99
+ def reset(self):
100
+ self.correct_count = 0.0
101
+ self.total_count = 0.0
102
+ self.label_correct_count = dict()
103
+ self.label_total_count = dict()
104
+
105
+
106
+ def demo1():
107
+
108
+ categorical_accuracy_verbose = CategoricalAccuracyVerbose(
109
+ index_to_token={0: '0', 1: '1'},
110
+ top_k=2,
111
+ )
112
+
113
+ predictions = torch.randn(size=(2, 3), dtype=torch.float32)
114
+ gold_labels = torch.ones(size=(2,), dtype=torch.long)
115
+ # print(predictions)
116
+ # print(gold_labels)
117
+
118
+ categorical_accuracy_verbose(
119
+ predictions=predictions,
120
+ gold_labels=gold_labels,
121
+ )
122
+ metric = categorical_accuracy_verbose.get_metric()
123
+ print(metric)
124
+ return
125
+
126
+
127
+ if __name__ == '__main__':
128
+ demo1()
toolbox/torch/training/trainer/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ if __name__ == '__main__':
5
+ pass
toolbox/torch/training/trainer/trainer.py ADDED
@@ -0,0 +1,5 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ if __name__ == '__main__':
5
+ pass
toolbox/torch/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ if __name__ == '__main__':
5
+ pass
toolbox/torchaudio/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ if __name__ == '__main__':
5
+ pass
toolbox/torchaudio/configuration_utils.py ADDED
@@ -0,0 +1,63 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import copy
4
+ import os
5
+ from typing import Any, Dict, Union
6
+
7
+ import yaml
8
+
9
+
10
+ CONFIG_FILE = "config.yaml"
11
+
12
+
13
+ class PretrainedConfig(object):
14
+ def __init__(self, **kwargs):
15
+ pass
16
+
17
+ @classmethod
18
+ def _dict_from_yaml_file(cls, yaml_file: Union[str, os.PathLike]):
19
+ with open(yaml_file, encoding="utf-8") as f:
20
+ config_dict = yaml.safe_load(f)
21
+ return config_dict
22
+
23
+ @classmethod
24
+ def get_config_dict(
25
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike]
26
+ ) -> Dict[str, Any]:
27
+ if os.path.isdir(pretrained_model_name_or_path):
28
+ config_file = os.path.join(pretrained_model_name_or_path, CONFIG_FILE)
29
+ else:
30
+ config_file = pretrained_model_name_or_path
31
+ config_dict = cls._dict_from_yaml_file(config_file)
32
+ return config_dict
33
+
34
+ @classmethod
35
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
36
+ for k, v in kwargs.items():
37
+ if k in config_dict.keys():
38
+ config_dict[k] = v
39
+ config = cls(**config_dict)
40
+ return config
41
+
42
+ @classmethod
43
+ def from_pretrained(
44
+ cls,
45
+ pretrained_model_name_or_path: Union[str, os.PathLike],
46
+ **kwargs,
47
+ ):
48
+ config_dict = cls.get_config_dict(pretrained_model_name_or_path)
49
+ return cls.from_dict(config_dict, **kwargs)
50
+
51
+ def to_dict(self):
52
+ output = copy.deepcopy(self.__dict__)
53
+ return output
54
+
55
+ def to_yaml_file(self, yaml_file_path: Union[str, os.PathLike]):
56
+ config_dict = self.to_dict()
57
+
58
+ with open(yaml_file_path, "w", encoding="utf-8") as writer:
59
+ yaml.safe_dump(config_dict, writer)
60
+
61
+
62
+ if __name__ == '__main__':
63
+ pass
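
A sketch of the intended config round trip; Conv2dClassifierConfig is a hypothetical subclass used only for illustration:

    from toolbox.torchaudio.configuration_utils import PretrainedConfig

    class Conv2dClassifierConfig(PretrainedConfig):
        # hypothetical subclass: store a couple of model hyper-parameters
        def __init__(self, num_labels: int = 2, sample_rate: int = 8000, **kwargs):
            super().__init__(**kwargs)
            self.num_labels = num_labels
            self.sample_rate = sample_rate

    config = Conv2dClassifierConfig(num_labels=4)
    config.to_yaml_file("config.yaml")

    restored = Conv2dClassifierConfig.from_pretrained("config.yaml")
    print(restored.to_dict())   # {'num_labels': 4, 'sample_rate': 8000}
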