HoneyTian committed
Commit f0a00b6
1 Parent(s): 69ad385
.gitignore CHANGED
@@ -8,11 +8,11 @@
 **/logs/
 **/__pycache__/
 
-data/
-docs/
-dotenv/
-trained_models/
-temp/
+/data/
+/docs/
+/dotenv/
+/trained_models/
+/temp/
 
 #**/*.wav
 **/*.xlsx
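Note on the change: a .gitignore pattern without a leading slash, such as "data/", matches a directory named data at any depth in the repository, while a leading slash anchors the pattern to the directory containing the .gitignore file. With "/data/" and the other anchored patterns, only the top-level data/, docs/, dotenv/, trained_models/ and temp/ directories are ignored; same-named directories nested deeper (for example a data/ inside examples/) are no longer excluded.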
examples/vm_sound_classification/run.sh CHANGED
@@ -13,7 +13,7 @@ E:/programmer/asr_datasets/voicemail/wav_finished/id-ID/wav_finished/*/*.wav" \
 sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification3
 sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification3
 
-sh run.sh --stage 3 --stop_stage 5 --system_version centos --file_folder_name file_dir --final_model_name vm_sound_classification8-ch16 \
+sh run.sh --stage 0 --stop_stage 1 --system_version centos --file_folder_name file_dir --final_model_name vm_sound_classification8-ch16 \
 --filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
 
 
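Note on the change: the committed example invocation now runs stages 0 through 1 instead of 3 through 5 on centos. The trailing backslash continues the command onto the next line, and the glob passed to --filename_patterns is quoted so the calling shell hands it to run.sh as a literal pattern rather than expanding it in place.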
toolbox/torch/utils/data/__init__.py ADDED
@@ -0,0 +1,5 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+if __name__ == '__main__':
+    pass
toolbox/torch/utils/data/dataset/__init__.py ADDED
@@ -0,0 +1,5 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+if __name__ == '__main__':
+    pass
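These two boilerplate __init__.py files simply mark toolbox.torch.utils.data and toolbox.torch.utils.data.dataset as regular Python packages, which the imports in the two new modules below (for example "from toolbox.torch.utils.data.vocabulary import Vocabulary") rely on.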
toolbox/torch/utils/data/dataset/wave_classifier_excel_dataset.py ADDED
@@ -0,0 +1,98 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import os
+
+import librosa
+import numpy as np
+import pandas as pd
+from scipy.io import wavfile
+import torch
+import torchaudio
+from torch.utils.data import Dataset
+from tqdm import tqdm
+
+from toolbox.torch.utils.data.vocabulary import Vocabulary
+
+
+class WaveClassifierExcelDataset(Dataset):
+    def __init__(self,
+                 vocab: Vocabulary,
+                 excel_file: str,
+                 expected_sample_rate: int,
+                 resample: bool = False,
+                 root_path: str = None,
+                 category: str = None,
+                 category_field: str = "category",
+                 label_field: str = "labels",
+                 max_wave_value: float = 1.0,
+                 ) -> None:
+        self.vocab = vocab
+        self.excel_file = excel_file
+
+        self.expected_sample_rate = expected_sample_rate
+        self.resample = resample
+        self.root_path = root_path
+        self.category = category
+        self.category_field = category_field
+        self.label_field = label_field
+        self.max_wave_value = max_wave_value
+
+        df = pd.read_excel(excel_file)
+
+        samples = list()
+        for i, row in tqdm(df.iterrows(), total=len(df)):
+            filename = row["filename"]
+            label = row[self.label_field]
+
+            if self.category is not None and self.category != row[self.category_field]:
+                continue
+
+            samples.append({
+                "filename": filename,
+                "label": label,
+            })
+        self.samples = samples
+
+    def __getitem__(self, index):
+        sample = self.samples[index]
+        filename = sample["filename"]
+        label = sample["label"]
+
+        if self.root_path is not None:
+            filename = os.path.join(self.root_path, filename)
+
+        waveform = self.filename_to_waveform(filename)
+
+        namespace = self.label_field if self.category is None else self.category
+        token_to_index = self.vocab.get_token_to_index_vocabulary(namespace=namespace)
+        label: int = token_to_index[label]
+
+        result = {
+            "waveform": waveform,
+            "label": torch.tensor(label, dtype=torch.int64),
+        }
+        return result
+
+    def __len__(self):
+        return len(self.samples)
+
+    def filename_to_waveform(self, filename: str):
+        try:
+            if self.resample:
+                waveform, sample_rate = librosa.load(filename, sr=self.expected_sample_rate)
+                # waveform, sample_rate = torchaudio.load(filename, normalize=True)
+            else:
+                sample_rate, waveform = wavfile.read(filename)
+                waveform = waveform / self.max_wave_value
+        except ValueError as e:
+            print(filename)
+            raise e
+        if sample_rate != self.expected_sample_rate:
+            raise AssertionError
+
+        waveform = torch.tensor(waveform, dtype=torch.float32)
+        return waveform
+
+
+if __name__ == "__main__":
+    pass
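A minimal usage sketch for the new dataset, built on the Vocabulary class added below. Everything here is hypothetical: the label values, the "dataset.xlsx" path, and the batch size; the spreadsheet is assumed to have "filename" and "labels" columns, matching the defaults above.

from torch.utils.data import DataLoader

from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
from toolbox.torch.utils.data.vocabulary import Vocabulary

# Labels live in the "labels" namespace because label_field defaults
# to "labels" and no category is given.
vocab = Vocabulary()
vocab.add_token_to_namespace("voicemail", namespace="labels")
vocab.add_token_to_namespace("bell", namespace="labels")

dataset = WaveClassifierExcelDataset(
    vocab=vocab,
    excel_file="dataset.xlsx",      # placeholder; pandas needs openpyxl to read .xlsx
    expected_sample_rate=8000,      # files at any other rate raise AssertionError
    resample=False,
    max_wave_value=32768.0,         # scales int16 PCM from wavfile.read into [-1, 1]
)

# Each item is {"waveform": float32 tensor, "label": int64 scalar tensor}.
# The default collate_fn can only stack equal-length waveforms, so use
# fixed-length clips or supply a padding collate_fn.
data_loader = DataLoader(dataset, batch_size=16, shuffle=True)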
toolbox/torch/utils/data/vocabulary.py ADDED
@@ -0,0 +1,214 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+from collections import defaultdict, OrderedDict
+import os
+from typing import Any, Callable, Dict, Iterable, List, Set
+
+
+def namespace_match(pattern: str, namespace: str):
+    """
+    Matches a namespace pattern against a namespace string. For example, ``*tags`` matches
+    ``passage_tags`` and ``question_tags`` and ``tokens`` matches ``tokens`` but not
+    ``stemmed_tokens``.
+    """
+    if pattern[0] == '*' and namespace.endswith(pattern[1:]):
+        return True
+    elif pattern == namespace:
+        return True
+    return False
+
+
+class _NamespaceDependentDefaultDict(defaultdict):
+    def __init__(self,
+                 non_padded_namespaces: Set[str],
+                 padded_function: Callable[[], Any],
+                 non_padded_function: Callable[[], Any]) -> None:
+        self._non_padded_namespaces = set(non_padded_namespaces)
+        self._padded_function = padded_function
+        self._non_padded_function = non_padded_function
+        super(_NamespaceDependentDefaultDict, self).__init__()
+
+    def __missing__(self, key: str):
+        if any(namespace_match(pattern, key) for pattern in self._non_padded_namespaces):
+            value = self._non_padded_function()
+        else:
+            value = self._padded_function()
+        dict.__setitem__(self, key, value)
+        return value
+
+    def add_non_padded_namespaces(self, non_padded_namespaces: Set[str]):
+        # add non_padded_namespaces which weren't already present
+        self._non_padded_namespaces.update(non_padded_namespaces)
+
+
+class _TokenToIndexDefaultDict(_NamespaceDependentDefaultDict):
+    def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
+        super(_TokenToIndexDefaultDict, self).__init__(non_padded_namespaces,
+                                                       lambda: {padding_token: 0, oov_token: 1},
+                                                       lambda: {})
+
+
+class _IndexToTokenDefaultDict(_NamespaceDependentDefaultDict):
+    def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
+        super(_IndexToTokenDefaultDict, self).__init__(non_padded_namespaces,
+                                                       lambda: {0: padding_token, 1: oov_token},
+                                                       lambda: {})
+
+
+DEFAULT_NON_PADDED_NAMESPACES = ("*tags", "*labels")
+DEFAULT_PADDING_TOKEN = '[PAD]'
+DEFAULT_OOV_TOKEN = '[UNK]'
+NAMESPACE_PADDING_FILE = 'non_padded_namespaces.txt'
+
+
+class Vocabulary(object):
+    def __init__(self, non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES):
+        self._non_padded_namespaces = set(non_padded_namespaces)
+        self._padding_token = DEFAULT_PADDING_TOKEN
+        self._oov_token = DEFAULT_OOV_TOKEN
+        self._token_to_index = _TokenToIndexDefaultDict(self._non_padded_namespaces,
+                                                        self._padding_token,
+                                                        self._oov_token)
+        self._index_to_token = _IndexToTokenDefaultDict(self._non_padded_namespaces,
+                                                        self._padding_token,
+                                                        self._oov_token)
+
+    def add_token_to_namespace(self, token: str, namespace: str = 'tokens') -> int:
+        if token not in self._token_to_index[namespace]:
+            index = len(self._token_to_index[namespace])
+            self._token_to_index[namespace][token] = index
+            self._index_to_token[namespace][index] = token
+            return index
+        else:
+            return self._token_to_index[namespace][token]
+
+    def get_index_to_token_vocabulary(self, namespace: str = 'tokens') -> Dict[int, str]:
+        return self._index_to_token[namespace]
+
+    def get_token_to_index_vocabulary(self, namespace: str = 'tokens') -> Dict[str, int]:
+        return self._token_to_index[namespace]
+
+    def get_token_index(self, token: str, namespace: str = 'tokens') -> int:
+        if token in self._token_to_index[namespace]:
+            return self._token_to_index[namespace][token]
+        else:
+            return self._token_to_index[namespace][self._oov_token]
+
+    def get_token_from_index(self, index: int, namespace: str = 'tokens'):
+        return self._index_to_token[namespace][index]
+
+    def get_vocab_size(self, namespace: str = 'tokens') -> int:
+        return len(self._token_to_index[namespace])
+
+    def save_to_files(self, directory: str):
+        os.makedirs(directory, exist_ok=True)
+        with open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'w', encoding='utf-8') as f:
+            for namespace_str in self._non_padded_namespaces:
+                f.write('{}\n'.format(namespace_str))
+
+        for namespace, token_to_index in self._token_to_index.items():
+            filename = os.path.join(directory, '{}.txt'.format(namespace))
+            with open(filename, 'w', encoding='utf-8') as f:
+                for token, index in token_to_index.items():
+                    # skip the padding token at index 0; set_from_file re-creates
+                    # it at load time, so writing it would shift all indices by one
+                    if index == 0 and token == self._padding_token:
+                        continue
+                    f.write('{}\n'.format(token))
+
+    @classmethod
+    def from_files(cls, directory: str) -> 'Vocabulary':
+        with open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'r', encoding='utf-8') as f:
+            non_padded_namespaces = [namespace_str.strip() for namespace_str in f]
+
+        vocab = cls(non_padded_namespaces=non_padded_namespaces)
+
+        for namespace_filename in os.listdir(directory):
+            if namespace_filename == NAMESPACE_PADDING_FILE:
+                continue
+            if namespace_filename.startswith("."):
+                continue
+            namespace = namespace_filename.replace('.txt', '')
+            if any(namespace_match(pattern, namespace) for pattern in non_padded_namespaces):
+                is_padded = False
+            else:
+                is_padded = True
+            filename = os.path.join(directory, namespace_filename)
+            vocab.set_from_file(filename, is_padded, namespace=namespace)
+
+        return vocab
+
+    def set_from_file(self,
+                      filename: str,
+                      is_padded: bool = True,
+                      oov_token: str = DEFAULT_OOV_TOKEN,
+                      namespace: str = "tokens"
+                      ):
+        if is_padded:
+            self._token_to_index[namespace] = {self._padding_token: 0}
+            self._index_to_token[namespace] = {0: self._padding_token}
+        else:
+            self._token_to_index[namespace] = {}
+            self._index_to_token[namespace] = {}
+
+        with open(filename, 'r', encoding='utf-8') as f:
+            index = 1 if is_padded else 0
+            for row in f:
+                token = str(row).strip()
+                if token == oov_token:
+                    token = self._oov_token
+                self._token_to_index[namespace][token] = index
+                self._index_to_token[namespace][index] = token
+                index += 1
+
+    def convert_tokens_to_ids(self, tokens: List[str], namespace: str = "tokens"):
+        result = list()
+        for token in tokens:
+            idx = self._token_to_index[namespace].get(token)
+            if idx is None:
+                idx = self._token_to_index[namespace][self._oov_token]
+            result.append(idx)
+        return result
+
+    def convert_ids_to_tokens(self, ids: List[int], namespace: str = "tokens"):
+        result = list()
+        for idx in ids:
+            idx = self._index_to_token[namespace][idx]
+            result.append(idx)
+        return result
+
+    def pad_or_truncate_ids_by_max_length(self, ids: List[int], max_length: int, namespace: str = "tokens"):
+        pad_idx = self._token_to_index[namespace][self._padding_token]
+
+        length = len(ids)
+        if length > max_length:
+            result = ids[:max_length]
+        else:
+            result = ids + [pad_idx] * (max_length - length)
+        return result
+
+
+def demo1():
+    import jieba
+
+    vocabulary = Vocabulary()
+    vocabulary.add_token_to_namespace('白天', 'tokens')
+    vocabulary.add_token_to_namespace('晚上', 'tokens')
+
+    text = '不是在白天, 就是在晚上'
+    tokens = jieba.lcut(text)
+
+    print(tokens)
+
+    ids = vocabulary.convert_tokens_to_ids(tokens)
+    print(ids)
+
+    padded_idx = vocabulary.pad_or_truncate_ids_by_max_length(ids, 10)
+    print(padded_idx)
+
+    tokens = vocabulary.convert_ids_to_tokens(padded_idx)
+    print(tokens)
+    return
+
+
+if __name__ == '__main__':
+    demo1()
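This class closely follows AllenNLP's namespaced-vocabulary design: a namespace matching one of the non-padded patterns ("*tags", "*labels" by default) starts empty, while every other namespace reserves index 0 for [PAD] and index 1 for [UNK]. A short sketch (the tokens and the "vocab_dir" directory are made up):

from toolbox.torch.utils.data.vocabulary import Vocabulary

vocab = Vocabulary()

# "labels" matches the non-padded pattern "*labels": indices start at 0,
# with no special tokens.
vocab.add_token_to_namespace("voicemail", namespace="labels")  # -> 0
vocab.add_token_to_namespace("bell", namespace="labels")       # -> 1

# "tokens" is a padded namespace: [PAD] and [UNK] already occupy 0 and 1.
vocab.add_token_to_namespace("hello", namespace="tokens")      # -> 2

print(vocab.get_token_index("unseen", namespace="tokens"))     # 1, the [UNK] index

# Persistence round trip: one <namespace>.txt per namespace plus
# non_padded_namespaces.txt are written into the directory.
vocab.save_to_files("vocab_dir")
restored = Vocabulary.from_files("vocab_dir")
print(restored.get_vocab_size("tokens"))                       # 3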