ZhongYing
commited on
Commit
·
503ec99
1
Parent(s):
6ffe62c
first commit
Browse files- .gitignore +1 -0
- configs/__init__.py +18 -0
- configs/config.py +47 -0
- configs/config.yml +49 -0
- dataset.py +263 -0
- featurizers/__init__.py +0 -0
- featurizers/gammatone.py +233 -0
- featurizers/speech_featurizers.py +453 -0
- librosa_mel_filter.csv +0 -0
- models/__init__.py +0 -0
- models/layers/__init__.py +0 -0
- models/layers/attention.py +35 -0
- models/model.py +98 -0
- optimizers/__init.py +0 -0
- optimizers/schedules.py +81 -0
- predict_by_pb.py +30 -0
- predict_by_weights.py +36 -0
- train.py +275 -0
- util/__init__.py +0 -0
- util/utils.py +78 -0
- vocab/__init__.py +0 -0
- vocab/vocab.py +11 -0
- vocab/vocab.txt +14 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
**/__pycache
|
configs/__init__.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import yaml
|
3 |
+
|
4 |
+
|
5 |
+
def load_yaml(path):
|
6 |
+
loader = yaml.SafeLoader
|
7 |
+
loader.add_implicit_resolver(
|
8 |
+
u'tag:yaml.org,2002:float',
|
9 |
+
re.compile(u'''^(?:
|
10 |
+
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|
11 |
+
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|
12 |
+
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|
13 |
+
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|
14 |
+
|[-+]?\\.(?:inf|Inf|INF)
|
15 |
+
|\\.(?:nan|NaN|NAN))$''', re.X),
|
16 |
+
list(u'-+0123456789.'))
|
17 |
+
with open(path, "r", encoding="utf-8") as file:
|
18 |
+
return yaml.load(file, Loader=loader)
|
configs/config.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 by zhongying
|
2 |
+
#
|
3 |
+
|
4 |
+
|
5 |
+
from . import load_yaml
|
6 |
+
from util.utils import preprocess_paths
|
7 |
+
|
8 |
+
|
9 |
+
class Config:
|
10 |
+
""" User configs class for training, testing or infering """
|
11 |
+
|
12 |
+
def __init__(self, path: str):
|
13 |
+
print('configs file path:', path)
|
14 |
+
config = load_yaml(preprocess_paths(path))
|
15 |
+
self.speech_config = config.get("speech_config", {})
|
16 |
+
self.model_config = config.get("model_config", {})
|
17 |
+
self.dataset_config = config.get("dataset_config", {})
|
18 |
+
self.optimizer_config = config.get("optimizer_config", {})
|
19 |
+
self.running_config = config.get("running_config", {})
|
20 |
+
|
21 |
+
def print(self):
|
22 |
+
print('==================================================')
|
23 |
+
print('speech configs:', self.speech_config)
|
24 |
+
print('--------------------------------------------------')
|
25 |
+
print('model configs:', self.model_config)
|
26 |
+
print('--------------------------------------------------')
|
27 |
+
print('dataset configs:', self.dataset_config)
|
28 |
+
print('--------------------------------------------------')
|
29 |
+
print('optimizer configs', self.optimizer_config)
|
30 |
+
print('--------------------------------------------------')
|
31 |
+
print('running configs:', self.running_config)
|
32 |
+
print('==================================================')
|
33 |
+
|
34 |
+
def toString(self):
|
35 |
+
string = ''
|
36 |
+
string += '#==================================================' + '\n'
|
37 |
+
string += '#speech config: ' + str(self.speech_config) + '\n'
|
38 |
+
string += '#--------------------------------------------------' + '\n'
|
39 |
+
string += '#model config: ' + str(self.model_config) + '\n'
|
40 |
+
string += '#--------------------------------------------------' + '\n'
|
41 |
+
string += '#dataset config: ' + str(self.dataset_config) + '\n'
|
42 |
+
string += '#--------------------------------------------------' + '\n'
|
43 |
+
string += '#optimizer config: ' + str(self.optimizer_config) + '\n'
|
44 |
+
string += '#--------------------------------------------------' + '\n'
|
45 |
+
string += '#running config: ' + str(self.running_config) + '\n'
|
46 |
+
string += '#==================================================' + '\n'
|
47 |
+
return string
|
configs/config.yml
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
speech_config:
|
2 |
+
sample_rate: 16000
|
3 |
+
frame_ms: 25
|
4 |
+
stride_ms: 10
|
5 |
+
num_feature_bins: 80
|
6 |
+
feature_type: log_mel_spectrogram
|
7 |
+
preemphasis: 0.97
|
8 |
+
normalize_signal: True
|
9 |
+
normalize_feature: True
|
10 |
+
normalize_per_feature: False
|
11 |
+
|
12 |
+
model_config:
|
13 |
+
name: acrnn
|
14 |
+
d_model: 64
|
15 |
+
filters: [32,64,64]
|
16 |
+
kernel_size: [[11,5],[11,5],[11,5]]
|
17 |
+
rnn_cell: 256
|
18 |
+
seq_mask: True
|
19 |
+
|
20 |
+
dataset_config:
|
21 |
+
vocabulary: vocab/vocab.txt
|
22 |
+
data_path: ./data/wavs/
|
23 |
+
corpus_name: ./data/demo_txt/demo
|
24 |
+
file_nums: 1
|
25 |
+
max_audio_length: 2000
|
26 |
+
shuffle_size: 1200
|
27 |
+
data_length: None
|
28 |
+
suffix: .txt
|
29 |
+
load_type: txt
|
30 |
+
train: train
|
31 |
+
dev: dev
|
32 |
+
test: test
|
33 |
+
|
34 |
+
optimizer_config:
|
35 |
+
init_steps: 0
|
36 |
+
warmup_steps: 10000
|
37 |
+
max_lr: 1e-4
|
38 |
+
beta1: 0.9
|
39 |
+
beta2: 0.999
|
40 |
+
epsilon: 1e-9
|
41 |
+
|
42 |
+
running_config:
|
43 |
+
prefetch: False
|
44 |
+
load_weights: ./saved_weights/20230228-084356/last/model
|
45 |
+
num_epochs: 100
|
46 |
+
batch_size: 1
|
47 |
+
train_steps: 50
|
48 |
+
dev_steps: 10
|
49 |
+
test_steps: 10
|
dataset.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from featurizers.speech_featurizers import SpeechFeaturizer
|
2 |
+
from configs.config import Config
|
3 |
+
from random import shuffle
|
4 |
+
import numpy as np
|
5 |
+
from vocab.vocab import Vocab
|
6 |
+
import os
|
7 |
+
import math
|
8 |
+
import librosa
|
9 |
+
import tensorflow as tf
|
10 |
+
|
11 |
+
|
12 |
+
def wav_padding(wav_data_lst, wav_max_len, fbank_dim):
|
13 |
+
wav_lens = [len(data) for data in wav_data_lst]
|
14 |
+
# input wav from 1200 frames down sample 8 times to 150 frames
|
15 |
+
wav_lens = [math.ceil(x/8) for x in wav_lens]
|
16 |
+
wav_lens = np.array(wav_lens)
|
17 |
+
new_wav_data_lst = np.zeros((len(wav_data_lst), wav_max_len, fbank_dim))
|
18 |
+
for i in range(len(wav_data_lst)):
|
19 |
+
new_wav_data_lst[i, :wav_data_lst[i].shape[0], :] = wav_data_lst[i]
|
20 |
+
return new_wav_data_lst, wav_lens
|
21 |
+
|
22 |
+
|
23 |
+
class DatDataSet:
|
24 |
+
def __init__(self,
|
25 |
+
batch_size,
|
26 |
+
data_type,
|
27 |
+
vocab: Vocab,
|
28 |
+
speech_featurizer: SpeechFeaturizer,
|
29 |
+
config: Config):
|
30 |
+
self.batch_size = batch_size
|
31 |
+
self.data_type = data_type
|
32 |
+
self.vocab = vocab
|
33 |
+
self.data_path =config.dataset_config['data_path']
|
34 |
+
self.corpus_name = config.dataset_config['corpus_name']
|
35 |
+
self.fbank_dim = config.speech_config['num_feature_bins']
|
36 |
+
self.max_audio_length =config.dataset_config['max_audio_length']
|
37 |
+
self.mel_banks = config.speech_config['num_feature_bins']
|
38 |
+
self.file_nums = config.dataset_config['file_nums']
|
39 |
+
self.language_classes = config.running_config['language_classes']
|
40 |
+
self.suffix = config.dataset_config['suffix']
|
41 |
+
self.READ_BUFFER_SIZE = 2 * 1024 * 1024 * 1024
|
42 |
+
self.shuffle = True
|
43 |
+
self.blank = 0
|
44 |
+
self.source_init()
|
45 |
+
|
46 |
+
|
47 |
+
def source_init(self):
|
48 |
+
self.dat_file_list, self.txt_file_list = self.get_dat_txt_list(self.data_type)
|
49 |
+
print('>>', self.data_type, 'load dat files:', len(self.dat_file_list))
|
50 |
+
print('>>', self.data_type, 'load txt files:', len(self.txt_file_list))
|
51 |
+
max_binary_file_size = max([os.path.getsize(dat) for dat in self.dat_file_list])
|
52 |
+
print('>> max binary file size:', max_binary_file_size)
|
53 |
+
# alloc a huge memory block
|
54 |
+
self.feature_binary = np.zeros(max_binary_file_size // 4 + 1, np.float32)
|
55 |
+
|
56 |
+
|
57 |
+
def get_dat_txt_list(self, dir_name):
|
58 |
+
corpus_dir = self.data_path+'/'+self.corpus_name + '/'
|
59 |
+
print('!!', corpus_dir)
|
60 |
+
file_lst = os.listdir(corpus_dir)
|
61 |
+
txt_file_lst = []
|
62 |
+
dat_file_lst = []
|
63 |
+
|
64 |
+
for align_file in file_lst:
|
65 |
+
if align_file.endswith(self.suffix):
|
66 |
+
file_name = align_file[:-len(self.suffix)]
|
67 |
+
dat_file = file_name + '.dat'
|
68 |
+
if dir_name in file_name:
|
69 |
+
# if dir_name in ['dev', 'test']:
|
70 |
+
# dat_file = dat_file.replace(dir_name, 'train')
|
71 |
+
dat_file_lst.append(corpus_dir + dat_file)
|
72 |
+
txt_file_lst.append(corpus_dir + align_file)
|
73 |
+
print('*********',dir_name, txt_file_lst, dat_file_lst)
|
74 |
+
return dat_file_lst, txt_file_lst
|
75 |
+
|
76 |
+
|
77 |
+
def load_dat_file(self, dat_file_path):
|
78 |
+
f = open(dat_file_path, 'rb')
|
79 |
+
pos = 0
|
80 |
+
buf = f.read(self.READ_BUFFER_SIZE)
|
81 |
+
while len(buf) > 0:
|
82 |
+
nbuf = np.frombuffer(buf, np.float32)
|
83 |
+
self.feature_binary[pos: pos + len(nbuf)] = nbuf
|
84 |
+
pos += len(nbuf)
|
85 |
+
buf = f.read(self.READ_BUFFER_SIZE)
|
86 |
+
|
87 |
+
|
88 |
+
def get_batch(self):
|
89 |
+
while 1:
|
90 |
+
shuffle_did_list = [i for i in range(len(self.dat_file_list))]
|
91 |
+
if self.shuffle:
|
92 |
+
shuffle(shuffle_did_list)
|
93 |
+
for did in shuffle_did_list:
|
94 |
+
wav_lst = []
|
95 |
+
label_lst = []
|
96 |
+
self.load_dat_file(self.dat_file_list[did])
|
97 |
+
txt_file = open(self.txt_file_list[did], 'r', encoding='utf8')
|
98 |
+
utt_lines = txt_file.readlines()
|
99 |
+
txt_lines = utt_lines
|
100 |
+
if self.shuffle:
|
101 |
+
shuffle(txt_lines)
|
102 |
+
# sort lines by wav len
|
103 |
+
# txt_lines = sorted(
|
104 |
+
# txt_lines,
|
105 |
+
# key=lambda line: int(line.split('\t')[0].split(':')[2]) - int(line.split('\t')[0].split(':')[1]),
|
106 |
+
# reverse=False)
|
107 |
+
for line in txt_lines:
|
108 |
+
wav_file, label = line.split('\t')
|
109 |
+
wav_lst.append(wav_file)
|
110 |
+
label_lst.append(label.strip('\n'))
|
111 |
+
shuffle_list = [i for i in range(len(wav_lst) // self.batch_size)]
|
112 |
+
if self.shuffle:
|
113 |
+
shuffle(shuffle_list)
|
114 |
+
for i in shuffle_list:
|
115 |
+
begin = i * self.batch_size
|
116 |
+
end = begin + self.batch_size
|
117 |
+
sub_list = list(range(begin, end, 1))
|
118 |
+
# label batch
|
119 |
+
label_data_lst = [label_lst[index] for index in sub_list]
|
120 |
+
prediction = np.array(
|
121 |
+
[self.vocab.token_list.index(line) for
|
122 |
+
line in label_data_lst],
|
123 |
+
dtype=np.int32)
|
124 |
+
|
125 |
+
feature_lst = []
|
126 |
+
wav_path = []
|
127 |
+
get_next_batch = False
|
128 |
+
for index in sub_list:
|
129 |
+
# data_aishell/wav/test/S0764/BAC009S0764W0121.wav:0:33680 chinese
|
130 |
+
_, start, end = wav_lst[index].split(':')
|
131 |
+
feature = self.feature_binary[int(start): int(end)]
|
132 |
+
feature = np.reshape(feature, (-1, 80))
|
133 |
+
feature = feature[:self.max_audio_length, :]
|
134 |
+
feature_lst.append(feature)
|
135 |
+
wav_path.append(wav_lst[index])
|
136 |
+
|
137 |
+
if get_next_batch:
|
138 |
+
continue
|
139 |
+
features, input_length = wav_padding(feature_lst, self.max_audio_length, self.fbank_dim)
|
140 |
+
|
141 |
+
yield features, input_length, prediction
|
142 |
+
|
143 |
+
|
144 |
+
class TxtDataSet:
|
145 |
+
def __init__(self,
|
146 |
+
batch_size,
|
147 |
+
data_type,
|
148 |
+
vocab,
|
149 |
+
speech_featurizer: SpeechFeaturizer,
|
150 |
+
config: Config
|
151 |
+
):
|
152 |
+
self.batch_size = batch_size
|
153 |
+
self.data_type = data_type
|
154 |
+
self.vocab = vocab
|
155 |
+
self.feature_extracter = speech_featurizer
|
156 |
+
self.data_path = config.dataset_config['data_path']
|
157 |
+
self.corpus_name = config.dataset_config['corpus_name']
|
158 |
+
self.fbank_dim = config.speech_config['num_feature_bins']
|
159 |
+
self.max_audio_length =config.dataset_config['max_audio_length']
|
160 |
+
self.mel_banks = config.speech_config['num_feature_bins']
|
161 |
+
self.file_nums = config.dataset_config['file_nums']
|
162 |
+
self.data_length = config.dataset_config['data_length']
|
163 |
+
self.shuffle = True
|
164 |
+
self.sentence_list = []
|
165 |
+
self.wav_lst = []
|
166 |
+
self.label_lst = []
|
167 |
+
self.max_sentence_length = 0
|
168 |
+
self.source_init()
|
169 |
+
|
170 |
+
def source_init(self):
|
171 |
+
read_files = []
|
172 |
+
if self.data_type == 'train':
|
173 |
+
read_files.append(self.corpus_name + '_train.txt')
|
174 |
+
elif self.data_type == 'dev':
|
175 |
+
read_files.append(self.corpus_name + '_dev.txt')
|
176 |
+
elif self.data_type == 'test':
|
177 |
+
read_files.append(self.corpus_name + '_test.txt')
|
178 |
+
print('data type:{} \n files:{}'.format(self.data_type, read_files))
|
179 |
+
total_lines = 0
|
180 |
+
for sub_file in read_files:
|
181 |
+
with open(sub_file, 'r', encoding='utf8') as f:
|
182 |
+
for line in f:
|
183 |
+
wav_file, label = line.split(' ', 1)
|
184 |
+
label = label.strip('\n').split()
|
185 |
+
|
186 |
+
self.label_lst.append(label)
|
187 |
+
self.wav_lst.append(wav_file)
|
188 |
+
total_lines += 1
|
189 |
+
if self.data_length:
|
190 |
+
if total_lines == self.data_length:
|
191 |
+
break
|
192 |
+
if total_lines % 10000 == 0:
|
193 |
+
print('\rload', total_lines, end='', flush=True)
|
194 |
+
|
195 |
+
if not self.data_length:
|
196 |
+
self.wav_lst = self.wav_lst[:self.data_length]
|
197 |
+
self.label_lst = self.label_lst[:self.data_length]
|
198 |
+
print('number of', self.data_type, 'data:', len(self.wav_lst))
|
199 |
+
|
200 |
+
|
201 |
+
def get_batch(self):
|
202 |
+
shuffle_list = [i for i in range(len(self.wav_lst))]
|
203 |
+
while 1:
|
204 |
+
if self.shuffle:
|
205 |
+
shuffle(shuffle_list)
|
206 |
+
for i in range(len(self.wav_lst) // self.batch_size):
|
207 |
+
begin = i * self.batch_size
|
208 |
+
end = begin + self.batch_size
|
209 |
+
sub_list = shuffle_list[begin:end]
|
210 |
+
|
211 |
+
label_data_lst = [self.label_lst[index] for index in sub_list]
|
212 |
+
prediction = np.array(
|
213 |
+
[self.vocab.token_list.index(line[0]) for
|
214 |
+
line in label_data_lst],
|
215 |
+
dtype=np.int32)
|
216 |
+
feature_lst = []
|
217 |
+
wav_path = []
|
218 |
+
get_next_batch = False
|
219 |
+
for index in sub_list:
|
220 |
+
# start = time.time()
|
221 |
+
audio, _ = librosa.load(self.data_path + self.wav_lst[index], sr=16000)
|
222 |
+
if len(audio) == 0:
|
223 |
+
get_next_batch = True
|
224 |
+
break
|
225 |
+
feature = self.feature_extracter.extract(audio)
|
226 |
+
|
227 |
+
feature_lst.append(feature)
|
228 |
+
wav_path.append(self.wav_lst[index])
|
229 |
+
|
230 |
+
if get_next_batch:
|
231 |
+
continue # get next batch
|
232 |
+
|
233 |
+
features, input_length = wav_padding(feature_lst, self.max_audio_length, self.fbank_dim)
|
234 |
+
|
235 |
+
yield features,input_length, prediction
|
236 |
+
|
237 |
+
|
238 |
+
def create_dataset(batch_size, load_type, data_type, speech_featurizer, config, vocab):
|
239 |
+
"""
|
240 |
+
batch_size: global batch size
|
241 |
+
data_type: the type of lode data, supports type: txt, dat()
|
242 |
+
|
243 |
+
"""
|
244 |
+
if load_type == 'dat':
|
245 |
+
dataset = DatDataSet(batch_size, data_type, vocab, speech_featurizer, config)
|
246 |
+
dataset = tf.data.Dataset.from_generator(dataset.get_batch,
|
247 |
+
output_types=(tf.float32, tf.int32, tf.int32),
|
248 |
+
output_shapes=([None, None, config.speech_config['num_feature_bins']],
|
249 |
+
[None], [None]))
|
250 |
+
elif load_type == 'txt':
|
251 |
+
dataset = TxtDataSet(batch_size, data_type, vocab, speech_featurizer, config)
|
252 |
+
dataset = tf.data.Dataset.from_generator(dataset.get_batch,
|
253 |
+
output_types=(tf.float32, tf.int32, tf.int32),
|
254 |
+
output_shapes=([None, None, config.speech_config['num_feature_bins']],
|
255 |
+
[None], [None]))
|
256 |
+
else:
|
257 |
+
print('load_type must be dat or txt!!')
|
258 |
+
return
|
259 |
+
|
260 |
+
options = tf.data.Options()
|
261 |
+
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA.DATA
|
262 |
+
dataset = dataset.with_options(options)
|
263 |
+
return dataset
|
featurizers/__init__.py
ADDED
File without changes
|
featurizers/gammatone.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" This code is inspired from https://github.com/detly/gammatone """
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import tensorflow as tf
|
5 |
+
|
6 |
+
from util.utils import shape_list
|
7 |
+
|
8 |
+
pi = tf.constant(np.pi, dtype=tf.complex64)
|
9 |
+
|
10 |
+
DEFAULT_FILTER_NUM = 100
|
11 |
+
DEFAULT_LOW_FREQ = 100
|
12 |
+
DEFAULT_HIGH_FREQ = 44100 / 4
|
13 |
+
|
14 |
+
|
15 |
+
def fft_weights(
|
16 |
+
nfft,
|
17 |
+
fs,
|
18 |
+
nfilts,
|
19 |
+
width,
|
20 |
+
fmin,
|
21 |
+
fmax,
|
22 |
+
maxlen):
|
23 |
+
"""
|
24 |
+
:param nfft: the source FFT size
|
25 |
+
:param sr: sampling rate (Hz)
|
26 |
+
:param nfilts: the number of output bands required (default 64)
|
27 |
+
:param width: the constant width of each band in Bark (default 1)
|
28 |
+
:param fmin: lower limit of frequencies (Hz)
|
29 |
+
:param fmax: upper limit of frequencies (Hz)
|
30 |
+
:param maxlen: number of bins to truncate the rows to
|
31 |
+
|
32 |
+
:return: a tuple `weights`, `gain` with the calculated weight matrices and
|
33 |
+
gain vectors
|
34 |
+
|
35 |
+
Generate a matrix of weights to combine FFT bins into Gammatone bins.
|
36 |
+
|
37 |
+
Note about `maxlen` parameter: While wts has nfft columns, the second half
|
38 |
+
are all zero. Hence, aud spectrum is::
|
39 |
+
|
40 |
+
fft2gammatonemx(nfft,sr)*abs(fft(xincols,nfft))
|
41 |
+
|
42 |
+
`maxlen` truncates the rows to this many bins.
|
43 |
+
|
44 |
+
| (c) 2004-2009 Dan Ellis dpwe@ee.columbia.edu based on rastamat/audspec.m
|
45 |
+
| (c) 2012 Jason Heeris (Python implementation)
|
46 |
+
"""
|
47 |
+
ucirc = tf.exp(1j * 2 * pi * tf.cast(tf.range(0, nfft / 2 + 1),
|
48 |
+
tf.complex64) / nfft)[None, ...]
|
49 |
+
|
50 |
+
# Common ERB filter code factored out
|
51 |
+
cf_array = erb_space(fmin, fmax, nfilts)[::-1]
|
52 |
+
|
53 |
+
_, A11, A12, A13, A14, _, _, _, B2, gain = make_erb_filters(fs, cf_array, width)
|
54 |
+
|
55 |
+
A11, A12, A13, A14 = A11[..., None], A12[..., None], A13[..., None], A14[..., None]
|
56 |
+
|
57 |
+
r = tf.cast(tf.sqrt(B2), tf.complex64)
|
58 |
+
theta = 2 * pi * cf_array / fs
|
59 |
+
pole = (r * tf.exp(1j * theta))[..., None]
|
60 |
+
|
61 |
+
GTord = 4
|
62 |
+
|
63 |
+
weights = (
|
64 |
+
tf.abs(ucirc + A11 * fs) * tf.abs(ucirc + A12 * fs)
|
65 |
+
* tf.abs(ucirc + A13 * fs) * tf.abs(ucirc + A14 * fs)
|
66 |
+
* tf.abs(fs * (pole - ucirc) * (tf.math.conj(pole) - ucirc)) ** (-GTord)
|
67 |
+
/ tf.cast(gain[..., None], tf.float32)
|
68 |
+
)
|
69 |
+
|
70 |
+
weights = tf.pad(weights, [[0, 0], [0, nfft - shape_list(weights)[-1]]])
|
71 |
+
|
72 |
+
weights = weights[:, 0:int(maxlen)]
|
73 |
+
|
74 |
+
return tf.transpose(weights, perm=[1, 0])
|
75 |
+
|
76 |
+
|
77 |
+
def erb_point(low_freq, high_freq, fraction):
|
78 |
+
"""
|
79 |
+
Calculates a single point on an ERB scale between ``low_freq`` and
|
80 |
+
``high_freq``, determined by ``fraction``. When ``fraction`` is ``1``,
|
81 |
+
``low_freq`` will be returned. When ``fraction`` is ``0``, ``high_freq``
|
82 |
+
will be returned.
|
83 |
+
|
84 |
+
``fraction`` can actually be outside the range ``[0, 1]``, which in general
|
85 |
+
isn't very meaningful, but might be useful when ``fraction`` is rounded a
|
86 |
+
little above or below ``[0, 1]`` (eg. for plot axis labels).
|
87 |
+
"""
|
88 |
+
# Change the following three parameters if you wish to use a different ERB
|
89 |
+
# scale. Must change in MakeERBCoeffs too.
|
90 |
+
# TODO: Factor these parameters out
|
91 |
+
ear_q = 9.26449 # Glasberg and Moore Parameters
|
92 |
+
min_bw = 24.7
|
93 |
+
|
94 |
+
# All of the following expressions are derived in Apple TR #35, "An
|
95 |
+
# Efficient Implementation of the Patterson-Holdsworth Cochlear Filter
|
96 |
+
# Bank." See pages 33-34.
|
97 |
+
erb_point = (
|
98 |
+
-ear_q * min_bw
|
99 |
+
+ tf.exp(
|
100 |
+
fraction * (
|
101 |
+
-tf.math.log(high_freq + ear_q * min_bw)
|
102 |
+
+ tf.math.log(low_freq + ear_q * min_bw)
|
103 |
+
)
|
104 |
+
) *
|
105 |
+
(high_freq + ear_q * min_bw)
|
106 |
+
)
|
107 |
+
|
108 |
+
return tf.cast(erb_point, tf.complex64)
|
109 |
+
|
110 |
+
|
111 |
+
def erb_space(
|
112 |
+
low_freq=DEFAULT_LOW_FREQ,
|
113 |
+
high_freq=DEFAULT_HIGH_FREQ,
|
114 |
+
num=DEFAULT_FILTER_NUM):
|
115 |
+
"""
|
116 |
+
This function computes an array of ``num`` frequencies uniformly spaced
|
117 |
+
between ``high_freq`` and ``low_freq`` on an ERB scale.
|
118 |
+
|
119 |
+
For a definition of ERB, see Moore, B. C. J., and Glasberg, B. R. (1983).
|
120 |
+
"Suggested formulae for calculating auditory-filter bandwidths and
|
121 |
+
excitation patterns," J. Acoust. Soc. Am. 74, 750-753.
|
122 |
+
"""
|
123 |
+
return erb_point(
|
124 |
+
low_freq,
|
125 |
+
high_freq,
|
126 |
+
tf.range(1, num + 1, dtype=tf.float32) / num
|
127 |
+
)
|
128 |
+
|
129 |
+
|
130 |
+
def make_erb_filters(fs, centre_freqs, width=1.0):
|
131 |
+
"""
|
132 |
+
This function computes the filter coefficients for a bank of
|
133 |
+
Gammatone filters. These filters were defined by Patterson and Holdworth for
|
134 |
+
simulating the cochlea.
|
135 |
+
|
136 |
+
The result is returned as a :class:`ERBCoeffArray`. Each row of the
|
137 |
+
filter arrays contains the coefficients for four second order filters. The
|
138 |
+
transfer function for these four filters share the same denominator (poles)
|
139 |
+
but have different numerators (zeros). All of these coefficients are
|
140 |
+
assembled into one vector that the ERBFilterBank can take apart to implement
|
141 |
+
the filter.
|
142 |
+
|
143 |
+
The filter bank contains "numChannels" channels that extend from
|
144 |
+
half the sampling rate (fs) to "lowFreq". Alternatively, if the numChannels
|
145 |
+
input argument is a vector, then the values of this vector are taken to be
|
146 |
+
the center frequency of each desired filter. (The lowFreq argument is
|
147 |
+
ignored in this case.)
|
148 |
+
|
149 |
+
Note this implementation fixes a problem in the original code by
|
150 |
+
computing four separate second order filters. This avoids a big problem with
|
151 |
+
round off errors in cases of very small cfs (100Hz) and large sample rates
|
152 |
+
(44kHz). The problem is caused by roundoff error when a number of poles are
|
153 |
+
combined, all very close to the unit circle. Small errors in the eigth order
|
154 |
+
coefficient, are multiplied when the eigth root is taken to give the pole
|
155 |
+
location. These small errors lead to poles outside the unit circle and
|
156 |
+
instability. Thanks to Julius Smith for leading me to the proper
|
157 |
+
explanation.
|
158 |
+
|
159 |
+
Execute the following code to evaluate the frequency response of a 10
|
160 |
+
channel filterbank::
|
161 |
+
|
162 |
+
fcoefs = MakeERBFilters(16000,10,100);
|
163 |
+
y = ERBFilterBank([1 zeros(1,511)], fcoefs);
|
164 |
+
resp = 20*log10(abs(fft(y')));
|
165 |
+
freqScale = (0:511)/512*16000;
|
166 |
+
semilogx(freqScale(1:255),resp(1:255,:));
|
167 |
+
axis([100 16000 -60 0])
|
168 |
+
xlabel('Frequency (Hz)'); ylabel('Filter Response (dB)');
|
169 |
+
|
170 |
+
| Rewritten by Malcolm Slaney@Interval. June 11, 1998.
|
171 |
+
| (c) 1998 Interval Research Corporation
|
172 |
+
|
|
173 |
+
| (c) 2012 Jason Heeris (Python implementation)
|
174 |
+
"""
|
175 |
+
T = 1 / fs
|
176 |
+
# Change the followFreqing three parameters if you wish to use a different
|
177 |
+
# ERB scale. Must change in ERBSpace too.
|
178 |
+
# TODO: factor these out
|
179 |
+
ear_q = 9.26449 # Glasberg and Moore Parameters
|
180 |
+
min_bw = 24.7
|
181 |
+
order = 1
|
182 |
+
|
183 |
+
erb = width * ((centre_freqs / ear_q) ** order + min_bw ** order) ** (1 / order)
|
184 |
+
B = 1.019 * 2 * pi * erb
|
185 |
+
|
186 |
+
arg = 2 * centre_freqs * pi * T
|
187 |
+
vec = tf.exp(2j * arg)
|
188 |
+
|
189 |
+
A0 = T
|
190 |
+
A2 = 0
|
191 |
+
B0 = 1
|
192 |
+
B1 = -2 * tf.cos(arg) / tf.exp(B * T)
|
193 |
+
B2 = tf.exp(-2 * B * T)
|
194 |
+
|
195 |
+
rt_pos = tf.cast(tf.sqrt(3 + 2 ** 1.5), tf.complex64)
|
196 |
+
rt_neg = tf.cast(tf.sqrt(3 - 2 ** 1.5), tf.complex64)
|
197 |
+
|
198 |
+
common = -T * tf.exp(-(B * T))
|
199 |
+
|
200 |
+
# TODO: This could be simplified to a matrix calculation involving the
|
201 |
+
# constant first term and the alternating rt_pos/rt_neg and +/-1 second
|
202 |
+
# terms
|
203 |
+
k11 = tf.cos(arg) + rt_pos * tf.sin(arg)
|
204 |
+
k12 = tf.cos(arg) - rt_pos * tf.sin(arg)
|
205 |
+
k13 = tf.cos(arg) + rt_neg * tf.sin(arg)
|
206 |
+
k14 = tf.cos(arg) - rt_neg * tf.sin(arg)
|
207 |
+
|
208 |
+
A11 = common * k11
|
209 |
+
A12 = common * k12
|
210 |
+
A13 = common * k13
|
211 |
+
A14 = common * k14
|
212 |
+
|
213 |
+
gain_arg = tf.exp(1j * arg - B * T)
|
214 |
+
|
215 |
+
gain = tf.cast(tf.abs(
|
216 |
+
(vec - gain_arg * k11)
|
217 |
+
* (vec - gain_arg * k12)
|
218 |
+
* (vec - gain_arg * k13)
|
219 |
+
* (vec - gain_arg * k14)
|
220 |
+
* (T * tf.exp(B * T)
|
221 |
+
/ (-1 / tf.exp(B * T) + 1 + vec * (1 - tf.exp(B * T)))
|
222 |
+
)**4
|
223 |
+
), tf.complex64)
|
224 |
+
|
225 |
+
allfilts = tf.ones_like(centre_freqs, dtype=tf.complex64)
|
226 |
+
|
227 |
+
fcoefs = tf.stack([
|
228 |
+
A0 * allfilts, A11, A12, A13, A14, A2 * allfilts,
|
229 |
+
B0 * allfilts, B1, B2,
|
230 |
+
gain
|
231 |
+
], axis=1)
|
232 |
+
|
233 |
+
return tf.transpose(fcoefs, perm=[1, 0])
|
featurizers/speech_featurizers.py
ADDED
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import io
|
3 |
+
import abc
|
4 |
+
import six
|
5 |
+
import numpy as np
|
6 |
+
import librosa
|
7 |
+
import soundfile as sf
|
8 |
+
import tensorflow as tf
|
9 |
+
|
10 |
+
from util.utils import log10
|
11 |
+
from .gammatone import fft_weights
|
12 |
+
|
13 |
+
|
14 |
+
def read_raw_audio(audio, sample_rate=16000):
|
15 |
+
if isinstance(audio, str):
|
16 |
+
wave, _ = librosa.load(os.path.expanduser(audio), sr=sample_rate)
|
17 |
+
elif isinstance(audio, bytes):
|
18 |
+
wave, sr = sf.read(io.BytesIO(audio))
|
19 |
+
wave = np.asfortranarray(wave)
|
20 |
+
if sr != sample_rate:
|
21 |
+
wave = librosa.resample(wave, sr, sample_rate)
|
22 |
+
elif isinstance(audio, np.ndarray):
|
23 |
+
return audio
|
24 |
+
else:
|
25 |
+
raise ValueError("input audio must be either a path or bytes")
|
26 |
+
return wave
|
27 |
+
|
28 |
+
|
29 |
+
def slice_signal(signal, window_size, stride=0.5) -> np.ndarray:
|
30 |
+
""" Return windows of the given signal by sweeping in stride fractions of window """
|
31 |
+
assert signal.ndim == 1, signal.ndim
|
32 |
+
n_samples = signal.shape[0]
|
33 |
+
offset = int(window_size * stride)
|
34 |
+
slices = []
|
35 |
+
for beg_i, end_i in zip(range(0, n_samples, offset),
|
36 |
+
range(window_size, n_samples + offset,
|
37 |
+
offset)):
|
38 |
+
slice_ = signal[beg_i:end_i]
|
39 |
+
if slice_.shape[0] < window_size:
|
40 |
+
slice_ = np.pad(
|
41 |
+
slice_, (0, window_size - slice_.shape[0]), 'constant', constant_values=0.0)
|
42 |
+
if slice_.shape[0] == window_size:
|
43 |
+
slices.append(slice_)
|
44 |
+
return np.array(slices, dtype=np.float32)
|
45 |
+
|
46 |
+
|
47 |
+
def tf_merge_slices(slices: tf.Tensor) -> tf.Tensor:
|
48 |
+
# slices shape = [batch, window_size]
|
49 |
+
return tf.keras.backend.flatten(slices) # return shape = [-1, ]
|
50 |
+
|
51 |
+
|
52 |
+
def merge_slices(slices: np.ndarray) -> np.ndarray:
|
53 |
+
# slices shape = [batch, window_size]
|
54 |
+
return np.reshape(slices, [-1])
|
55 |
+
|
56 |
+
|
57 |
+
def normalize_audio_feature(audio_feature: np.ndarray, per_feature=False):
|
58 |
+
""" Mean and variance normalization """
|
59 |
+
axis = 0 if per_feature else None
|
60 |
+
mean = np.mean(audio_feature, axis=axis)
|
61 |
+
std_dev = np.std(audio_feature, axis=axis) + 1e-9
|
62 |
+
normalized = (audio_feature - mean) / std_dev
|
63 |
+
return normalized
|
64 |
+
|
65 |
+
|
66 |
+
def tf_normalize_audio_features(audio_feature: tf.Tensor, per_feature=False):
|
67 |
+
"""
|
68 |
+
TF Mean and variance features normalization
|
69 |
+
Args:
|
70 |
+
audio_feature: tf.Tensor with shape [T, F]
|
71 |
+
|
72 |
+
Returns:
|
73 |
+
normalized audio features with shape [T, F]
|
74 |
+
"""
|
75 |
+
axis = 0 if per_feature else None
|
76 |
+
mean = tf.reduce_mean(audio_feature, axis=axis)
|
77 |
+
std_dev = tf.math.reduce_std(audio_feature, axis=axis) + 1e-9
|
78 |
+
return (audio_feature - mean) / std_dev
|
79 |
+
|
80 |
+
|
81 |
+
def normalize_signal(signal: np.ndarray):
|
82 |
+
""" Normailize signal to [-1, 1] range """
|
83 |
+
gain = 1.0 / (np.max(np.abs(signal)) + 1e-9)
|
84 |
+
return signal * gain
|
85 |
+
|
86 |
+
|
87 |
+
def tf_normalize_signal(signal: tf.Tensor):
|
88 |
+
"""
|
89 |
+
TF Normailize signal to [-1, 1] range
|
90 |
+
Args:
|
91 |
+
signal: tf.Tensor with shape [None]
|
92 |
+
|
93 |
+
Returns:
|
94 |
+
normalized signal with shape [None]
|
95 |
+
"""
|
96 |
+
gain = 1.0 / (tf.reduce_max(tf.abs(signal), axis=-1) + 1e-9)
|
97 |
+
return signal * gain
|
98 |
+
|
99 |
+
|
100 |
+
def preemphasis(signal: np.ndarray, coeff=0.97):
|
101 |
+
if not coeff or coeff <= 0.0:
|
102 |
+
return signal
|
103 |
+
return np.append(signal[0], signal[1:] - coeff * signal[:-1])
|
104 |
+
|
105 |
+
|
106 |
+
def tf_preemphasis(signal: tf.Tensor, coeff=0.97):
|
107 |
+
"""
|
108 |
+
TF Pre-emphasis
|
109 |
+
Args:
|
110 |
+
signal: tf.Tensor with shape [None]
|
111 |
+
coeff: Float that indicates the preemphasis coefficient
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
pre-emphasized signal with shape [None]
|
115 |
+
"""
|
116 |
+
if not coeff or coeff <= 0.0: return signal
|
117 |
+
s0 = tf.expand_dims(signal[0], axis=-1)
|
118 |
+
s1 = signal[1:] - coeff * signal[:-1]
|
119 |
+
return tf.concat([s0, s1], axis=-1)
|
120 |
+
|
121 |
+
|
122 |
+
def depreemphasis(signal: np.ndarray, coeff=0.97):
|
123 |
+
if not coeff or coeff <= 0.0: return signal
|
124 |
+
x = np.zeros(signal.shape[0], dtype=np.float32)
|
125 |
+
x[0] = signal[0]
|
126 |
+
for n in range(1, signal.shape[0], 1):
|
127 |
+
x[n] = coeff * x[n - 1] + signal[n]
|
128 |
+
return x
|
129 |
+
|
130 |
+
|
131 |
+
def tf_depreemphasis(signal: tf.Tensor, coeff=0.97):
|
132 |
+
"""
|
133 |
+
TF Depreemphasis
|
134 |
+
Args:
|
135 |
+
signal: tf.Tensor with shape [B, None]
|
136 |
+
coeff: Float that indicates the preemphasis coefficient
|
137 |
+
|
138 |
+
Returns:
|
139 |
+
depre-emphasized signal with shape [B, None]
|
140 |
+
"""
|
141 |
+
if not coeff or coeff <= 0.0: return signal
|
142 |
+
|
143 |
+
def map_fn(elem):
|
144 |
+
x = tf.expand_dims(elem[0], axis=-1)
|
145 |
+
for n in range(1, elem.shape[0], 1):
|
146 |
+
current = coeff * x[n - 1] + elem[n]
|
147 |
+
x = tf.concat([x, [current]], axis=0)
|
148 |
+
return x
|
149 |
+
|
150 |
+
return tf.map_fn(map_fn, signal)
|
151 |
+
|
152 |
+
|
153 |
+
class SpeechFeaturizer(metaclass=abc.ABCMeta):
|
154 |
+
def __init__(self, speech_config: dict):
|
155 |
+
"""
|
156 |
+
We should use TFSpeechFeaturizer for training to avoid differences
|
157 |
+
between tf and librosa when converting to tflite in post-training stage
|
158 |
+
speech_config = {
|
159 |
+
"sample_rate": int,
|
160 |
+
"frame_ms": int,
|
161 |
+
"stride_ms": int,
|
162 |
+
"num_feature_bins": int,
|
163 |
+
"feature_type": str,
|
164 |
+
"delta": bool,
|
165 |
+
"delta_delta": bool,
|
166 |
+
"pitch": bool,
|
167 |
+
"normalize_signal": bool,
|
168 |
+
"normalize_feature": bool,
|
169 |
+
"normalize_per_feature": bool
|
170 |
+
}
|
171 |
+
"""
|
172 |
+
# Samples
|
173 |
+
self.sample_rate = speech_config.get("sample_rate", 16000)
|
174 |
+
self.frame_length = int(self.sample_rate * (speech_config.get("frame_ms", 25) / 1000))
|
175 |
+
self.frame_step = int(self.sample_rate * (speech_config.get("stride_ms", 10) / 1000))
|
176 |
+
# Features
|
177 |
+
self.num_feature_bins = speech_config.get("num_feature_bins", 80)
|
178 |
+
self.feature_type = speech_config.get("feature_type", "log_mel_spectrogram")
|
179 |
+
self.preemphasis = speech_config.get("preemphasis", None)
|
180 |
+
# Normalization
|
181 |
+
self.normalize_signal = speech_config.get("normalize_signal", True)
|
182 |
+
self.normalize_feature = speech_config.get("normalize_feature", True)
|
183 |
+
self.normalize_per_feature = speech_config.get("normalize_per_feature", False)
|
184 |
+
# librosa mel filter
|
185 |
+
self.mel_filter = None
|
186 |
+
|
187 |
+
@property
|
188 |
+
def nfft(self) -> int:
|
189 |
+
""" Number of FFT """
|
190 |
+
return 2 ** (self.frame_length - 1).bit_length()
|
191 |
+
|
192 |
+
@property
|
193 |
+
def shape(self) -> list:
|
194 |
+
""" The shape of extracted features """
|
195 |
+
raise NotImplementedError()
|
196 |
+
|
197 |
+
@abc.abstractclassmethod
|
198 |
+
def stft(self, signal):
|
199 |
+
raise NotImplementedError()
|
200 |
+
|
201 |
+
@abc.abstractclassmethod
|
202 |
+
def power_to_db(self, S, ref=1.0, amin=1e-10, top_db=80.0):
|
203 |
+
raise NotImplementedError()
|
204 |
+
|
205 |
+
@abc.abstractmethod
|
206 |
+
def extract(self, signal):
|
207 |
+
""" Function to perform feature extraction """
|
208 |
+
raise NotImplementedError()
|
209 |
+
|
210 |
+
|
211 |
+
class NumpySpeechFeaturizer(SpeechFeaturizer):
|
212 |
+
def __init__(self, speech_config: dict):
|
213 |
+
super(NumpySpeechFeaturizer, self).__init__(speech_config)
|
214 |
+
self.delta = speech_config.get("delta", False)
|
215 |
+
self.delta_delta = speech_config.get("delta_delta", False)
|
216 |
+
self.pitch = speech_config.get("pitch", False)
|
217 |
+
|
218 |
+
@property
|
219 |
+
def shape(self) -> list:
|
220 |
+
# None for time dimension
|
221 |
+
channel_dim = 1
|
222 |
+
|
223 |
+
if self.delta:
|
224 |
+
channel_dim += 1
|
225 |
+
|
226 |
+
if self.delta_delta:
|
227 |
+
channel_dim += 1
|
228 |
+
|
229 |
+
if self.pitch:
|
230 |
+
channel_dim += 1
|
231 |
+
|
232 |
+
return [None, self.num_feature_bins, channel_dim]
|
233 |
+
|
234 |
+
def stft(self, signal):
|
235 |
+
return np.square(
|
236 |
+
np.abs(librosa.core.stft(signal, n_fft=self.nfft, hop_length=self.frame_step,
|
237 |
+
win_length=self.frame_length, center=True, window="hann")))
|
238 |
+
|
239 |
+
def power_to_db(self, S, ref=1.0, amin=1e-10, top_db=80.0):
|
240 |
+
return librosa.power_to_db(S, ref=ref, amin=amin, top_db=top_db)
|
241 |
+
|
242 |
+
def extract(self, signal: np.ndarray) -> np.ndarray:
|
243 |
+
signal = np.asfortranarray(signal)
|
244 |
+
if self.normalize_signal:
|
245 |
+
signal = normalize_signal(signal)
|
246 |
+
signal = preemphasis(signal, self.preemphasis)
|
247 |
+
|
248 |
+
if self.feature_type == "mfcc":
|
249 |
+
features = self.compute_mfcc(signal)
|
250 |
+
elif self.feature_type == "log_mel_spectrogram":
|
251 |
+
features = self.compute_log_mel_spectrogram(signal)
|
252 |
+
elif self.feature_type == "spectrogram":
|
253 |
+
features = self.compute_spectrogram(signal)
|
254 |
+
elif self.feature_type == "log_gammatone_spectrogram":
|
255 |
+
features = self.compute_log_gammatone_spectrogram(signal)
|
256 |
+
else:
|
257 |
+
raise ValueError("feature_type must be either 'mfcc', "
|
258 |
+
"'log_mel_spectrogram', 'log_gammatone_spectrogram' "
|
259 |
+
"or 'spectrogram'")
|
260 |
+
|
261 |
+
if self.normalize_feature:
|
262 |
+
features = normalize_audio_feature(features, per_feature=self.normalize_per_feature)
|
263 |
+
|
264 |
+
# features = np.expand_dims(features, axis=-1)
|
265 |
+
|
266 |
+
return features
|
267 |
+
|
268 |
+
def compute_pitch(self, signal: np.ndarray) -> np.ndarray:
|
269 |
+
pitches, _ = librosa.core.piptrack(
|
270 |
+
y=signal, sr=self.sample_rate,
|
271 |
+
n_fft=self.nfft, hop_length=self.frame_step,
|
272 |
+
fmin=0.0, fmax=int(self.sample_rate / 2), win_length=self.frame_length, center=True
|
273 |
+
)
|
274 |
+
|
275 |
+
pitches = pitches.T
|
276 |
+
|
277 |
+
assert self.num_feature_bins <= self.frame_length // 2 + 1, \
|
278 |
+
"num_features for spectrogram should \
|
279 |
+
be <= (sample_rate * window_size // 2 + 1)"
|
280 |
+
|
281 |
+
return pitches[:, :self.num_feature_bins]
|
282 |
+
|
283 |
+
def compute_spectrogram(self, signal: np.ndarray) -> np.ndarray:
|
284 |
+
powspec = self.stft(signal)
|
285 |
+
features = self.power_to_db(powspec.T)
|
286 |
+
|
287 |
+
assert self.num_feature_bins <= self.frame_length // 2 + 1, \
|
288 |
+
"num_features for spectrogram should \
|
289 |
+
be <= (sample_rate * window_size // 2 + 1)"
|
290 |
+
|
291 |
+
# cut high frequency part, keep num_feature_bins features
|
292 |
+
features = features[:, :self.num_feature_bins]
|
293 |
+
|
294 |
+
return features
|
295 |
+
|
296 |
+
def compute_mfcc(self, signal: np.ndarray) -> np.ndarray:
|
297 |
+
S = self.stft(signal)
|
298 |
+
|
299 |
+
mel = librosa.filters.mel(self.sample_rate, self.nfft,
|
300 |
+
n_mels=self.num_feature_bins,
|
301 |
+
fmin=0.0, fmax=int(self.sample_rate / 2))
|
302 |
+
|
303 |
+
mel_spectrogram = np.dot(S.T, mel.T)
|
304 |
+
|
305 |
+
mfcc = librosa.feature.mfcc(sr=self.sample_rate,
|
306 |
+
S=self.power_to_db(mel_spectrogram).T,
|
307 |
+
n_mfcc=self.num_feature_bins)
|
308 |
+
|
309 |
+
return mfcc.T
|
310 |
+
|
311 |
+
def compute_log_mel_spectrogram(self, signal: np.ndarray) -> np.ndarray:
|
312 |
+
S = self.stft(signal)
|
313 |
+
|
314 |
+
mel = librosa.filters.mel(self.sample_rate, self.nfft,
|
315 |
+
n_mels=self.num_feature_bins,
|
316 |
+
fmin=0.0, fmax=int(self.sample_rate / 2))
|
317 |
+
|
318 |
+
mel_spectrogram = np.dot(S.T, mel.T)
|
319 |
+
|
320 |
+
return self.power_to_db(mel_spectrogram)
|
321 |
+
|
322 |
+
def compute_log_gammatone_spectrogram(self, signal: np.ndarray) -> np.ndarray:
|
323 |
+
S = self.stft(signal)
|
324 |
+
|
325 |
+
gammatone = fft_weights(self.nfft, self.sample_rate,
|
326 |
+
self.num_feature_bins, width=1.0,
|
327 |
+
fmin=0, fmax=int(self.sample_rate / 2),
|
328 |
+
maxlen=(self.nfft / 2 + 1))
|
329 |
+
|
330 |
+
gammatone = gammatone.numpy().astype(np.float32)
|
331 |
+
|
332 |
+
gammatone_spectrogram = np.dot(S.T, gammatone)
|
333 |
+
|
334 |
+
return self.power_to_db(gammatone_spectrogram)
|
335 |
+
|
336 |
+
|
337 |
+
class TFSpeechFeaturizer(SpeechFeaturizer):
|
338 |
+
@property
|
339 |
+
def shape(self) -> list:
|
340 |
+
# None for time dimension
|
341 |
+
return [None, self.num_feature_bins, 1]
|
342 |
+
|
343 |
+
def stft(self, signal):
|
344 |
+
signal = tf.pad(signal, [[self.nfft // 2, self.nfft // 2]], mode="REFLECT")
|
345 |
+
window = tf.signal.hann_window(self.frame_length, periodic=True)
|
346 |
+
left_pad = (self.nfft - self.frame_length) // 2
|
347 |
+
right_pad = self.nfft - self.frame_length - left_pad
|
348 |
+
window = tf.pad(window, [[left_pad, right_pad]])
|
349 |
+
framed_signals = tf.signal.frame(signal, frame_length=self.nfft, frame_step=self.frame_step)
|
350 |
+
framed_signals *= window
|
351 |
+
return tf.square(tf.abs(tf.signal.rfft(framed_signals, [self.nfft])))
|
352 |
+
|
353 |
+
def power_to_db(self, S, ref=1.0, amin=1e-10, top_db=80.0):
|
354 |
+
if amin <= 0:
|
355 |
+
raise ValueError('amin must be strictly positive')
|
356 |
+
|
357 |
+
magnitude = S
|
358 |
+
|
359 |
+
if six.callable(ref):
|
360 |
+
# User supplied a function to calculate reference power
|
361 |
+
ref_value = ref(magnitude)
|
362 |
+
else:
|
363 |
+
ref_value = np.abs(ref)
|
364 |
+
|
365 |
+
log_spec = 10.0 * log10(tf.maximum(amin, magnitude))
|
366 |
+
log_spec -= 10.0 * log10(tf.maximum(amin, ref_value))
|
367 |
+
|
368 |
+
if top_db is not None:
|
369 |
+
if top_db < 0:
|
370 |
+
raise ValueError('top_db must be non-negative')
|
371 |
+
log_spec = tf.maximum(log_spec, tf.reduce_max(log_spec) - top_db)
|
372 |
+
|
373 |
+
return log_spec
|
374 |
+
|
375 |
+
def extract(self, signal: np.ndarray) -> np.ndarray:
|
376 |
+
signal = np.asfortranarray(signal)
|
377 |
+
features = self.tf_extract(tf.convert_to_tensor(signal, dtype=tf.float32))
|
378 |
+
return features.numpy()
|
379 |
+
|
380 |
+
def tf_extract(self, signal: tf.Tensor) -> tf.Tensor:
|
381 |
+
"""
|
382 |
+
Extract speech features from signals (for using in tflite)
|
383 |
+
Args:
|
384 |
+
signal: tf.Tensor with shape [None]
|
385 |
+
|
386 |
+
Returns:
|
387 |
+
features: tf.Tensor with shape [T, F]
|
388 |
+
"""
|
389 |
+
if self.normalize_signal:
|
390 |
+
signal = tf_normalize_signal(signal)
|
391 |
+
signal = tf_preemphasis(signal, self.preemphasis)
|
392 |
+
|
393 |
+
if self.feature_type == "spectrogram":
|
394 |
+
features = self.compute_spectrogram(signal)
|
395 |
+
elif self.feature_type == "log_mel_spectrogram":
|
396 |
+
features = self.compute_log_mel_spectrogram(signal)
|
397 |
+
elif self.feature_type == "mfcc":
|
398 |
+
features = self.compute_mfcc(signal)
|
399 |
+
elif self.feature_type == "log_gammatone_spectrogram":
|
400 |
+
features = self.compute_log_gammatone_spectrogram(signal)
|
401 |
+
else:
|
402 |
+
raise ValueError("feature_type must be either 'mfcc',"
|
403 |
+
"'log_mel_spectrogram' or 'spectrogram'")
|
404 |
+
|
405 |
+
if self.normalize_feature:
|
406 |
+
features = tf_normalize_audio_features(
|
407 |
+
features, per_feature=self.normalize_per_feature)
|
408 |
+
|
409 |
+
# features = tf.expand_dims(features, axis=-1)
|
410 |
+
|
411 |
+
return features
|
412 |
+
|
413 |
+
def compute_log_mel_spectrogram(self, signal):
|
414 |
+
spectrogram = self.stft(signal)
|
415 |
+
if self.mel_filter is None:
|
416 |
+
linear_to_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
|
417 |
+
num_mel_bins=self.num_feature_bins,
|
418 |
+
num_spectrogram_bins=spectrogram.shape[-1],
|
419 |
+
sample_rate=self.sample_rate,
|
420 |
+
lower_edge_hertz=0.0, upper_edge_hertz=(self.sample_rate / 2)
|
421 |
+
)
|
422 |
+
else:
|
423 |
+
linear_to_weight_matrix = self.mel_filter
|
424 |
+
|
425 |
+
mel_spectrogram = tf.tensordot(spectrogram, linear_to_weight_matrix, 1)
|
426 |
+
return self.power_to_db(mel_spectrogram)
|
427 |
+
|
428 |
+
def compute_spectrogram(self, signal):
|
429 |
+
S = self.stft(signal)
|
430 |
+
spectrogram = self.power_to_db(S)
|
431 |
+
return spectrogram[:, :self.num_feature_bins]
|
432 |
+
|
433 |
+
def compute_mfcc(self, signal):
|
434 |
+
log_mel_spectrogram = self.compute_log_mel_spectrogram(signal)
|
435 |
+
return tf.signal.mfccs_from_log_mel_spectrograms(log_mel_spectrogram)
|
436 |
+
|
437 |
+
def compute_log_gammatone_spectrogram(self, signal: np.ndarray) -> np.ndarray:
|
438 |
+
S = self.stft(signal)
|
439 |
+
|
440 |
+
gammatone = fft_weights(self.nfft, self.sample_rate,
|
441 |
+
self.num_feature_bins, width=1.0,
|
442 |
+
fmin=0, fmax=int(self.sample_rate / 2),
|
443 |
+
maxlen=(self.nfft / 2 + 1))
|
444 |
+
|
445 |
+
gammatone_spectrogram = tf.tensordot(S, gammatone, 1)
|
446 |
+
|
447 |
+
return self.power_to_db(gammatone_spectrogram)
|
448 |
+
|
449 |
+
def set_mel_filter(self, librosa_mel_filter):
|
450 |
+
"""
|
451 |
+
Set librosa mel filter.
|
452 |
+
"""
|
453 |
+
self.mel_filter = librosa_mel_filter
|
librosa_mel_filter.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/__init__.py
ADDED
File without changes
|
models/layers/__init__.py
ADDED
File without changes
|
models/layers/attention.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import tensorflow as tf
|
3 |
+
|
4 |
+
|
5 |
+
class Attention(tf.keras.layers.Layer):
|
6 |
+
def __init__(self, hidden_size,
|
7 |
+
attention_size=1,
|
8 |
+
name=None,
|
9 |
+
**kwargs):
|
10 |
+
super().__init__( **kwargs)
|
11 |
+
self.w_kernel = self.add_variable('w_kernel', [hidden_size, attention_size])
|
12 |
+
self.w_bias = self.add_variable('w_bias', [attention_size])
|
13 |
+
self.bias = self.add_variable('bias', [attention_size])
|
14 |
+
|
15 |
+
|
16 |
+
def call(self, inputs, inp_len, maxlen=150, mask=None, training=False, **kwargs):
|
17 |
+
"""
|
18 |
+
inp_len: length of input audio
|
19 |
+
maxlen: audio length after downsampling(cnn(twice downsample) and maxpool), in our experiments
|
20 |
+
the input length is 1200s, after downsampling, the sequence length is 1200//8=1500,
|
21 |
+
(8=2*2*2, see model parameters for details).
|
22 |
+
If you change input length and times of dowansampling,
|
23 |
+
please reset the maxlen parameter!!!!
|
24 |
+
"""
|
25 |
+
# In case of Bi-RNN, concatenate the forward and the backward Rnn outputs.
|
26 |
+
if isinstance(inputs, tuple):
|
27 |
+
inputs = tf.concat(inputs, 2)
|
28 |
+
v = tf.sigmoid(tf.tensordot(inputs, self.w_kernel, axes=1) + self.w_bias)
|
29 |
+
vu = tf.tensordot(v, self.bias, axes=1)
|
30 |
+
alphas = tf.nn.softmax(vu) #(B,T)
|
31 |
+
if mask is not None:
|
32 |
+
alphas = alphas*tf.cast(tf.sequence_mask(inp_len, maxlen), dtype=tf.float32)
|
33 |
+
output = tf.reduce_sum(inputs*tf.expand_dims(alphas, -1), 1)
|
34 |
+
|
35 |
+
return output
|
models/model.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from featurizers.speech_featurizers import SpeechFeaturizer
|
3 |
+
from .layers.attention import Attention
|
4 |
+
|
5 |
+
|
6 |
+
L2 = tf.keras.regularizers.l2(1e-6)
|
7 |
+
|
8 |
+
|
9 |
+
def shape_list(x, out_type=tf.int32):
|
10 |
+
"""Deal with dynamic shape in tensorflow cleanly."""
|
11 |
+
static = x.shape.as_list()
|
12 |
+
dynamic = tf.shape(x, out_type=out_type)
|
13 |
+
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
|
14 |
+
|
15 |
+
|
16 |
+
def merge_two_last_dims(x):
|
17 |
+
b, _, f, c = shape_list(x)
|
18 |
+
return tf.reshape(x, shape=[b, -1, f * c])
|
19 |
+
|
20 |
+
|
21 |
+
class MulSpeechLR(tf.keras.Model):
|
22 |
+
def __init__(self, name, filters, kernel_size, d_model, rnn_cell, seq_mask, vocab_size, dropout=0.5):
|
23 |
+
super(MulSpeechLR, self).__init__()
|
24 |
+
self.filters1 = filters[0]
|
25 |
+
self.filters2 = filters[1]
|
26 |
+
self.filters3 = filters[2]
|
27 |
+
self.kernel_size1 = kernel_size[0]
|
28 |
+
self.kernel_size2 = kernel_size[1]
|
29 |
+
self.kernel_size3 = kernel_size[2]
|
30 |
+
#during training, self.mask can be set true, but during inference, it must be false
|
31 |
+
self.mask = seq_mask
|
32 |
+
self.conv1 = tf.keras.layers.Conv2D(filters=self.filters1, kernel_size=self.kernel_size1,
|
33 |
+
strides=(2,2), padding='same', activation='relu')
|
34 |
+
self.maxpool1 = tf.keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2))
|
35 |
+
|
36 |
+
self.conv2 = tf.keras.layers.Conv2D(filters=self.filters2, kernel_size=self.kernel_size2,
|
37 |
+
strides=(2,2), padding='same', activation='relu')
|
38 |
+
self.conv3 = tf.keras.layers.Conv2D(filters=self.filters3, kernel_size=self.kernel_size3,
|
39 |
+
strides=(1,1), padding='same', activation='relu')
|
40 |
+
self.ln1 = tf.keras.layers.LayerNormalization(name=f"{name}_ln_1")
|
41 |
+
self.ln2 = tf.keras.layers.LayerNormalization(name=f"{name}_ln_2")
|
42 |
+
self.ln3 = tf.keras.layers.LayerNormalization(name=f"{name}_ln_3")
|
43 |
+
# self.linear1 = tf.keras.layers.Dense(d_model*2, name=f"{name}_dense_1")
|
44 |
+
self.linear2 = tf.keras.layers.Dense(d_model, name=f"{name}_dense_2")
|
45 |
+
self.rnn = tf.keras.layers.GRU(rnn_cell, return_sequences=True, return_state=True, name=f"{name}_gru")
|
46 |
+
self.attention = Attention(rnn_cell)
|
47 |
+
self.class_layer = tf.keras.layers.Dense(vocab_size)
|
48 |
+
self.res_add = tf.keras.layers.Add(name=f"{name}_add")
|
49 |
+
|
50 |
+
|
51 |
+
def call(self, inputs):
|
52 |
+
x, x_len = inputs
|
53 |
+
# mask = tf.cast(tf.sequence_mask(x_len, maxlen=150), dtype=tf.float32)
|
54 |
+
x = tf.expand_dims(x, axis=-1)
|
55 |
+
x = self.conv1(x)
|
56 |
+
x = self.ln1(x)
|
57 |
+
x = self.maxpool1(x)
|
58 |
+
x = self.conv2(x)
|
59 |
+
x = self.ln2(x)
|
60 |
+
x = self.conv3(x)
|
61 |
+
x = self.ln3(x)
|
62 |
+
x = merge_two_last_dims(x)
|
63 |
+
x, final_state = self.rnn(x)
|
64 |
+
x = self.attention(x, x_len, self.mask)
|
65 |
+
x = self.res_add([x, final_state])
|
66 |
+
output = self.linear2(x)
|
67 |
+
output = tf.nn.relu(output)
|
68 |
+
output = self.class_layer(output)
|
69 |
+
|
70 |
+
return output
|
71 |
+
|
72 |
+
|
73 |
+
def init_build(self, input_shape):
|
74 |
+
x = tf.keras.Input(shape=input_shape, dtype= tf.float32)
|
75 |
+
l = tf.keras.Input(shape=[], dtype=tf.int32)
|
76 |
+
self([x, l], training=False)
|
77 |
+
|
78 |
+
def add_featurizers(self,
|
79 |
+
speech_featurizer: SpeechFeaturizer):
|
80 |
+
"""
|
81 |
+
Function to add featurizer to model to convert to end2end tflite
|
82 |
+
Args:
|
83 |
+
speech_featurizer: SpeechFeaturizer instance
|
84 |
+
"""
|
85 |
+
self.speech_featurizer = speech_featurizer
|
86 |
+
|
87 |
+
|
88 |
+
@tf.function(input_signature=[tf.TensorSpec([None], dtype=tf.float32)])
|
89 |
+
def predict_pb(self, signal):
|
90 |
+
features = self.speech_featurizer.tf_extract(signal)
|
91 |
+
input_len = tf.expand_dims(tf.shape(features)[0], axis=0)
|
92 |
+
input = tf.expand_dims(features, axis=0)
|
93 |
+
output = self([input, input_len], training=False)
|
94 |
+
output = tf.nn.softmax(output)
|
95 |
+
output1 = tf.squeeze(output)
|
96 |
+
output = tf.argmax(output1, axis=-1)
|
97 |
+
|
98 |
+
return output, tf.gather(output1, output)
|
optimizers/__init.py
ADDED
File without changes
|
optimizers/schedules.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 by zhongying
|
2 |
+
|
3 |
+
import tensorflow as tf
|
4 |
+
from tensorflow.python.framework import ops
|
5 |
+
from tensorflow.python.ops import math_ops
|
6 |
+
from tensorflow.keras.optimizers.schedules import ExponentialDecay
|
7 |
+
|
8 |
+
|
9 |
+
class TransformerLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
|
10 |
+
""" Transformer learning rate schedule """
|
11 |
+
|
12 |
+
def __init__(self, d_model, init_steps=0, warmup_steps=4000, max_lr=None):
|
13 |
+
super(TransformerLRSchedule, self).__init__()
|
14 |
+
|
15 |
+
self.d_model = d_model
|
16 |
+
self.d_model = tf.cast(self.d_model, tf.float32)
|
17 |
+
self.max_lr = max_lr
|
18 |
+
self.warmup_steps = warmup_steps
|
19 |
+
self.init_steps = init_steps
|
20 |
+
|
21 |
+
def __call__(self, step):
|
22 |
+
# lr = (d_model^-0.5) * min(step^-0.5, step*(warm_up^-1.5))
|
23 |
+
step += self.init_steps
|
24 |
+
arg1 = tf.math.rsqrt(step)
|
25 |
+
arg2 = step * (self.warmup_steps ** -1.5)
|
26 |
+
lr = tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
|
27 |
+
if self.max_lr is not None:
|
28 |
+
return tf.math.minimum(self.max_lr, lr)
|
29 |
+
return lr
|
30 |
+
|
31 |
+
def get_config(self):
|
32 |
+
return {
|
33 |
+
"d_model": self.d_model,
|
34 |
+
"warmup_steps": self.warmup_steps,
|
35 |
+
"max_lr": self.max_lr
|
36 |
+
}
|
37 |
+
|
38 |
+
|
39 |
+
class SANSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
|
40 |
+
def __init__(self, lamb, d_model, warmup_steps=4000):
|
41 |
+
super(SANSchedule, self).__init__()
|
42 |
+
|
43 |
+
self.lamb = tf.cast(lamb, tf.float32)
|
44 |
+
self.d_model = tf.cast(d_model, tf.float32)
|
45 |
+
|
46 |
+
self.warmup_steps = tf.cast(warmup_steps, tf.float32)
|
47 |
+
|
48 |
+
def __call__(self, step):
|
49 |
+
arg1 = step / (self.warmup_steps ** 1.5)
|
50 |
+
arg2 = 1 / tf.math.sqrt(step)
|
51 |
+
|
52 |
+
return (self.lamb / tf.math.sqrt(self.d_model)) * tf.math.minimum(arg1, arg2)
|
53 |
+
|
54 |
+
def get_config(self):
|
55 |
+
return {
|
56 |
+
"lamb": self.lamb,
|
57 |
+
"d_model": self.d_model,
|
58 |
+
"warmup_steps": self.warmup_steps
|
59 |
+
}
|
60 |
+
|
61 |
+
|
62 |
+
class BoundExponentialDecay(ExponentialDecay):
|
63 |
+
def __init__(self, min_lr=0.0, **kwargs):
|
64 |
+
super().__init__(**kwargs)
|
65 |
+
self.min_lr = min_lr
|
66 |
+
|
67 |
+
def __call__(self, step):
|
68 |
+
with ops.name_scope_v2(self.name or "ExponentialDecay") as name:
|
69 |
+
initial_learning_rate = ops.convert_to_tensor(
|
70 |
+
self.initial_learning_rate, name="initial_learning_rate")
|
71 |
+
dtype = initial_learning_rate.dtype
|
72 |
+
decay_steps = math_ops.cast(self.decay_steps, dtype)
|
73 |
+
decay_rate = math_ops.cast(self.decay_rate, dtype)
|
74 |
+
|
75 |
+
global_step_recomp = math_ops.cast(step, dtype)
|
76 |
+
p = global_step_recomp / decay_steps
|
77 |
+
if self.staircase:
|
78 |
+
p = math_ops.floor(p)
|
79 |
+
new_lr = math_ops.multiply(
|
80 |
+
initial_learning_rate, math_ops.pow(decay_rate, p), name=name)
|
81 |
+
return math_ops.maximum(self.min_lr, new_lr)
|
predict_by_pb.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from signal import signal
|
2 |
+
import tensorflow as tf
|
3 |
+
gpus = tf.config.list_physical_devices('GPU')
|
4 |
+
tf.config.set_visible_devices(gpus[0:1], 'GPU')
|
5 |
+
from vocab.vocab import Vocab
|
6 |
+
import librosa
|
7 |
+
import numpy as np
|
8 |
+
import sys
|
9 |
+
import os
|
10 |
+
from tqdm import tqdm
|
11 |
+
from sklearn.metrics import accuracy_score
|
12 |
+
|
13 |
+
|
14 |
+
vocab = Vocab("vocab/vocab.txt")
|
15 |
+
model = tf.saved_model.load('saved_models/lang14/pb/2/')
|
16 |
+
|
17 |
+
|
18 |
+
def predict_wav(wav_path):
|
19 |
+
signal, _ = librosa.load(wav_path, sr=16000)
|
20 |
+
output, prob = model.predict_pb(signal)
|
21 |
+
language = vocab.token_list[output.numpy()]
|
22 |
+
print(language, prob.numpy()*100)
|
23 |
+
|
24 |
+
return output.numpy(), prob.numpy()
|
25 |
+
|
26 |
+
|
27 |
+
if __name__ == '__main__':
|
28 |
+
wav_path = sys.argv[1]
|
29 |
+
predict_wav(wav_path)
|
30 |
+
|
predict_by_weights.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
gpus = tf.config.list_physical_devices('GPU')
|
3 |
+
tf.config.set_visible_devices(gpus[0:1], 'GPU')
|
4 |
+
from vocab.vocab import Vocab
|
5 |
+
from dataset import create_dataset
|
6 |
+
from configs.config import Config
|
7 |
+
import sys
|
8 |
+
from featurizers.speech_featurizers import TFSpeechFeaturizer, NumpySpeechFeaturizer
|
9 |
+
from models.model import MulSpeechLR as Model
|
10 |
+
import librosa
|
11 |
+
|
12 |
+
|
13 |
+
weights_dir = './saved_weights/20230228-084356/'
|
14 |
+
config_file = weights_dir + 'config.yml'
|
15 |
+
model_file = weights_dir + 'last/model'
|
16 |
+
vocab_file = weights_dir + 'vocab.txt'
|
17 |
+
config = Config(config_file)
|
18 |
+
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
|
19 |
+
lr_vocab = Vocab(vocab_file)
|
20 |
+
lr_model = Model(**config.model_config, vocab_size=len(lr_vocab.token_list))
|
21 |
+
lr_model.load_weights(model_file)
|
22 |
+
lr_model.add_featurizers(speech_featurizer)
|
23 |
+
lr_model.init_build([None, config.speech_config['num_feature_bins']])
|
24 |
+
lr_model.summary()
|
25 |
+
|
26 |
+
|
27 |
+
def predict_wav(wav_path):
|
28 |
+
sample_rate = 16000
|
29 |
+
signal, _ = librosa.load(wav_path, sr=sample_rate)
|
30 |
+
predict, prob = lr_model.predict_pb(signal)
|
31 |
+
language = lr_vocab.token_list[predict.numpy()]
|
32 |
+
print("predict language={} prob={:.4f}".format(language, prob.numpy()*100))
|
33 |
+
|
34 |
+
if __name__ == '__main__':
|
35 |
+
wav_path = sys.argv[1]
|
36 |
+
predict_wav(wav_path)
|
train.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# copyright by speechflow 2023/03/17
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import tensorflow as tf
|
6 |
+
gpus = tf.config.list_physical_devices('GPU')
|
7 |
+
# tf.config.set_visible_devices(gpus[0:1], 'GPU')
|
8 |
+
import datetime
|
9 |
+
import time
|
10 |
+
import os
|
11 |
+
from shutil import copyfile
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
from vocab.vocab import Vocab
|
14 |
+
from configs.config import Config
|
15 |
+
from models.model import MulSpeechLR as Model
|
16 |
+
from termcolor import colored
|
17 |
+
from featurizers.speech_featurizers import NumpySpeechFeaturizer
|
18 |
+
from dataset import create_dataset
|
19 |
+
import tensorflow_addons as tfa
|
20 |
+
from sklearn.metrics import f1_score, recall_score, precision_score
|
21 |
+
mirrored_strategy = tf.distribute.MirroredStrategy()
|
22 |
+
|
23 |
+
|
24 |
+
def train(config_file):
|
25 |
+
config = Config(config_file)
|
26 |
+
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
27 |
+
dir_log_root = "./saved_weights/"
|
28 |
+
if not os.path.exists(dir_log_root):
|
29 |
+
os.mkdir(dir_log_root)
|
30 |
+
dir_current = dir_log_root + current_time
|
31 |
+
if not os.path.isdir(dir_log_root):
|
32 |
+
os.mkdir(dir_log_root)
|
33 |
+
if not os.path.isdir(dir_current):
|
34 |
+
os.mkdir(dir_current)
|
35 |
+
copyfile(config_file, dir_current + '/config.yml')
|
36 |
+
log_file = open(dir_current + '/log.txt', 'w')
|
37 |
+
copyfile(config.dataset_config['vocabulary'], dir_current + '/vocab.txt')
|
38 |
+
|
39 |
+
|
40 |
+
config.print()
|
41 |
+
log_file.write(config.toString())
|
42 |
+
# vocab_file.write(config.toString())
|
43 |
+
log_file.flush()
|
44 |
+
|
45 |
+
vocab = Vocab(config.dataset_config['vocabulary'])
|
46 |
+
batch_size = config.running_config['batch_size']
|
47 |
+
global_batch_size = batch_size * mirrored_strategy.num_replicas_in_sync
|
48 |
+
speech_featurizer = NumpySpeechFeaturizer(config.speech_config)
|
49 |
+
model = Model(**config.model_config, vocab_size=len(vocab.token_list))
|
50 |
+
if config.running_config['load_weights'] is not None:
|
51 |
+
model.load_weights(config.running_config['load_weights'])
|
52 |
+
model.add_featurizers(speech_featurizer)
|
53 |
+
model.init_build([None, config.speech_config['num_feature_bins']])
|
54 |
+
model.summary()
|
55 |
+
|
56 |
+
train_dataset = create_dataset(batch_size=global_batch_size,
|
57 |
+
load_type=config.dataset_config['load_type'],
|
58 |
+
data_type=config.dataset_config['train'],
|
59 |
+
speech_featurizer=speech_featurizer,
|
60 |
+
config = config,
|
61 |
+
vocab = vocab)
|
62 |
+
eval_dataset = create_dataset(batch_size=global_batch_size,
|
63 |
+
load_type=config.dataset_config['load_type'],
|
64 |
+
data_type=config.dataset_config['dev'],
|
65 |
+
speech_featurizer=speech_featurizer,
|
66 |
+
config = config,
|
67 |
+
vocab = vocab)
|
68 |
+
test_dataset = create_dataset(batch_size=global_batch_size,
|
69 |
+
load_type=config.dataset_config['load_type'],
|
70 |
+
data_type=config.dataset_config['test'],
|
71 |
+
speech_featurizer=speech_featurizer,
|
72 |
+
config = config,
|
73 |
+
vocab = vocab)
|
74 |
+
train_dist_batch = mirrored_strategy.experimental_distribute_dataset(train_dataset)
|
75 |
+
dev_dist_batch = mirrored_strategy.experimental_distribute_dataset(eval_dataset)
|
76 |
+
test_dist_batch = mirrored_strategy.experimental_distribute_dataset(test_dataset)
|
77 |
+
dev_loss = tf.keras.metrics.Mean(name='dev_loss')
|
78 |
+
train_loss = tf.keras.metrics.Mean(name='train_loss')
|
79 |
+
dev_accuracy = tf.keras.metrics.Mean(name='train_accuracy')
|
80 |
+
init_steps = config.optimizer_config['init_steps']
|
81 |
+
step = tf.Variable(init_steps)
|
82 |
+
|
83 |
+
optimizer = tf.keras.optimizers.Adam(lr=config.optimizer_config['max_lr'])
|
84 |
+
ckpt = tf.train.Checkpoint(step=step, optimizer=optimizer, model=model)
|
85 |
+
ckpt_manager = tf.train.CheckpointManager(ckpt, dir_current + '/ckpt', max_to_keep=5)
|
86 |
+
loss_object = tfa.losses.SigmoidFocalCrossEntropy(
|
87 |
+
from_logits = True,
|
88 |
+
alpha = 0.25,
|
89 |
+
gamma = 0,
|
90 |
+
reduction = tf.keras.losses.Reduction.NONE)
|
91 |
+
loss_object_label_smooth = tf.keras.losses.CategoricalCrossentropy(
|
92 |
+
from_logits=True, label_smoothing=0.1, reduction=tf.keras.losses.Reduction.NONE)
|
93 |
+
|
94 |
+
def compute_loss(real, pred, smooth=False):
|
95 |
+
if smooth:
|
96 |
+
loss_ = loss_object_label_smooth(tf.one_hot(real, len(vocab.token_list)), pred)
|
97 |
+
else:
|
98 |
+
real = tf.one_hot(real, len(vocab.token_list))
|
99 |
+
loss_ = loss_object(real, pred)
|
100 |
+
return tf.nn.compute_average_loss(loss_, global_batch_size=global_batch_size)
|
101 |
+
|
102 |
+
def accuracy_function(real, pred):
|
103 |
+
pred = tf.cast(pred, dtype=tf.int32)
|
104 |
+
accuracies = tf.equal(real, pred)
|
105 |
+
|
106 |
+
mask = tf.math.logical_not(tf.math.equal(real, 0))
|
107 |
+
accuracies = tf.math.logical_and(mask, accuracies)
|
108 |
+
|
109 |
+
accuracies = tf.cast(accuracies, dtype=tf.float32)
|
110 |
+
mask = tf.cast(mask, dtype=tf.float32)
|
111 |
+
|
112 |
+
return tf.reduce_sum(accuracies)/tf.reduce_sum(mask)
|
113 |
+
|
114 |
+
@tf.function
|
115 |
+
def train_step(input, input_length, target):
|
116 |
+
with tf.GradientTape() as tape:
|
117 |
+
predictions = model([input, input_length], training=True)
|
118 |
+
loss = compute_loss(target, predictions, smooth=True)
|
119 |
+
grads = tape.gradient(loss, model.trainable_variables)
|
120 |
+
optimizer.apply_gradients(zip(grads, model.trainable_variables))
|
121 |
+
return loss
|
122 |
+
|
123 |
+
@tf.function
|
124 |
+
def dev_step(input, input_length, target):
|
125 |
+
predictions = model([input, input_length], training=False)
|
126 |
+
t_loss = compute_loss(target, predictions, smooth=True)
|
127 |
+
|
128 |
+
return t_loss, predictions
|
129 |
+
|
130 |
+
@tf.function
|
131 |
+
def test_step(input, input_length, target):
|
132 |
+
predictions = model([input, input_length], training=False)
|
133 |
+
return predictions, target
|
134 |
+
|
135 |
+
@tf.function(experimental_relax_shapes=True)
|
136 |
+
def distributed_train_step(x, x_len, y):
|
137 |
+
per_replica_losses = mirrored_strategy.run(train_step, args=(x, x_len, y))
|
138 |
+
mean_loss = mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
|
139 |
+
return mean_loss
|
140 |
+
|
141 |
+
@tf.function(experimental_relax_shapes=True)
|
142 |
+
def distributed_dev_step(x, x_len, y):
|
143 |
+
per_replica_losses, per_replica_preds = mirrored_strategy.run(dev_step, args=(x, x_len, y))
|
144 |
+
mean_loss = mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
|
145 |
+
return mean_loss, per_replica_preds
|
146 |
+
|
147 |
+
|
148 |
+
@tf.function(experimental_relax_shapes=True)
|
149 |
+
def distributed_test_step(x, x_len, y):
|
150 |
+
return mirrored_strategy.run(test_step, args=(x, x_len, y))
|
151 |
+
|
152 |
+
plot_train_loss = []
|
153 |
+
plot_dev_loss = []
|
154 |
+
plot_acc, plot_precision = [], []
|
155 |
+
best_acc= 0
|
156 |
+
train_iter = iter(train_dist_batch)
|
157 |
+
dev_iter = iter(dev_dist_batch)
|
158 |
+
test_iter = iter(test_dist_batch)
|
159 |
+
|
160 |
+
for epoch in range(1, config.running_config['num_epochs'] + 1):
|
161 |
+
if config.dataset_config['load_type']=='txt':
|
162 |
+
train_iter = iter(train_dist_batch)
|
163 |
+
dev_iter = iter(dev_dist_batch)
|
164 |
+
test_iter = iter(test_dist_batch)
|
165 |
+
start = time.time()
|
166 |
+
# training loop
|
167 |
+
train_loss = 0.0
|
168 |
+
dev_loss = 0.0
|
169 |
+
for train_batches in range(config.running_config['train_steps']):
|
170 |
+
inp, inp_len, target = next(train_iter)
|
171 |
+
train_loss += distributed_train_step(inp, inp_len, target)
|
172 |
+
template = '\rEpoch {} Step {} Loss {:.4f}'
|
173 |
+
print(colored(template.format(
|
174 |
+
epoch, train_batches + 1, train_loss / (train_batches + 1),
|
175 |
+
), 'green'), end='', flush=True)
|
176 |
+
step.assign_add(1)
|
177 |
+
|
178 |
+
# validation loop
|
179 |
+
pred_all = tf.zeros([1], dtype=tf.int32)
|
180 |
+
true_all = tf.zeros([1], dtype=tf.int32)
|
181 |
+
for dev_batches in range(config.running_config['dev_steps']):
|
182 |
+
inp, inp_len, target = next(dev_iter)
|
183 |
+
loss, predicted_result = distributed_dev_step(inp, inp_len, target)
|
184 |
+
dev_loss += loss
|
185 |
+
if mirrored_strategy.num_replicas_in_sync == 1:
|
186 |
+
prediction = tf.nn.softmax(predicted_result)
|
187 |
+
y_pred = tf.argmax(prediction, axis=-1)
|
188 |
+
y_pred = tf.cast(y_pred, dtype=tf.int32)
|
189 |
+
pred_all = tf.concat([pred_all, y_pred], axis=0)
|
190 |
+
true_all = tf.concat([true_all, target], axis=0)
|
191 |
+
else:
|
192 |
+
for i in range(mirrored_strategy.num_replicas_in_sync):
|
193 |
+
predicted_result_per_replica = predicted_result.values[i]
|
194 |
+
y_true = target.values[i]
|
195 |
+
y_pred = tf.argmax(predicted_result_per_replica, axis=-1)
|
196 |
+
y_pred = tf.cast(y_pred, dtype=tf.int32)
|
197 |
+
pred_all = tf.concat([pred_all, y_pred], axis=0)
|
198 |
+
true_all = tf.concat([true_all, y_true], axis=0)
|
199 |
+
dev_accuracy = accuracy_function(true_all, pred_all)
|
200 |
+
|
201 |
+
pred_all = tf.zeros([1], dtype=tf.int32)
|
202 |
+
true_all = tf.zeros([1], dtype=tf.int32)
|
203 |
+
for test_batches in range(config.running_config['test_steps']):
|
204 |
+
inp, inp_len, target = next(test_iter)
|
205 |
+
predicted_result, target_result = distributed_test_step(inp, inp_len, target)
|
206 |
+
if mirrored_strategy.num_replicas_in_sync == 1:
|
207 |
+
prediction = tf.nn.softmax(predicted_result)
|
208 |
+
y_pred =tf.argmax(prediction, axis=-1)
|
209 |
+
y_pred = tf.cast(y_pred, dtype=tf.int32)
|
210 |
+
pred_all = tf.concat([pred_all, y_pred], axis=0)
|
211 |
+
true_all = tf.concat([true_all, target], axis=0)
|
212 |
+
else:
|
213 |
+
for replica in range(mirrored_strategy.num_replicas_in_sync):
|
214 |
+
predicted_result_per_replica = predicted_result.values[i]
|
215 |
+
y_true = target.values[i]
|
216 |
+
y_pred = tf.argmax(predicted_result_per_replica, axis=-1)
|
217 |
+
y_pred = tf.cast(y_pred, dtype=tf.int32)
|
218 |
+
pred_all = tf.concat([pred_all, y_pred], axis=0)
|
219 |
+
true_all = tf.concat([true_all, y_true], axis=0)
|
220 |
+
|
221 |
+
test_acc = accuracy_function(real=true_all, pred=pred_all)
|
222 |
+
|
223 |
+
test_f1 = f1_score(y_true=true_all, y_pred=pred_all, average='macro')
|
224 |
+
precision = precision_score(y_true=true_all, y_pred=pred_all, average='macro', zero_division=1)
|
225 |
+
recall = recall_score(y_true=true_all, y_pred=pred_all, average='macro')
|
226 |
+
if precision > best_acc:
|
227 |
+
best_acc = precision
|
228 |
+
model.save_weights(dir_current + '/best/' + 'model')
|
229 |
+
model.save_weights(dir_current + '/last/' + 'model')
|
230 |
+
template = ("\rEpoch {}, Loss: {:.4f}, Val Loss: {:.4f}, "
|
231 |
+
"Val Acc: {:.4f}, test ACC: {:.4f},F1: {:.4f}, precision: {:.4f}, recall: {:.4f}, Time Cost: {:.2f} sec")
|
232 |
+
text = template.format(epoch, train_loss / config.running_config['train_steps'],
|
233 |
+
dev_loss/ config.running_config['dev_steps'], dev_accuracy *100,
|
234 |
+
test_acc*100, test_f1*100, precision*100, recall*100, time.time() - start)
|
235 |
+
print(colored(text, 'cyan'))
|
236 |
+
log_file.write(text)
|
237 |
+
log_file.flush()
|
238 |
+
plot_train_loss.append(train_loss / config.running_config['train_steps'])
|
239 |
+
plot_dev_loss.append(dev_loss / config.running_config['dev_steps'])
|
240 |
+
plot_acc.append(test_acc)
|
241 |
+
plot_precision.append(precision)
|
242 |
+
ckpt_manager.save()
|
243 |
+
|
244 |
+
plt.plot(plot_train_loss, '-r', label='train_loss')
|
245 |
+
plt.title('Train Loss')
|
246 |
+
plt.xlabel('Epochs')
|
247 |
+
plt.savefig(dir_current + '/loss.png')
|
248 |
+
#plot dev
|
249 |
+
plt.clf()
|
250 |
+
plt.plot(plot_dev_loss, '-g', label='dev_loss')
|
251 |
+
plt.title('dev Loss')
|
252 |
+
plt.xlabel('Epochs')
|
253 |
+
plt.savefig(dir_current + '/dev_loss.png')
|
254 |
+
|
255 |
+
# plot acc curve
|
256 |
+
plt.clf()
|
257 |
+
plt.plot(plot_acc, 'b-', label='acc')
|
258 |
+
plt.title('Accuracy')
|
259 |
+
plt.xlabel('Epochs')
|
260 |
+
plt.savefig(dir_current + '/acc.png')
|
261 |
+
# plot f1 curve
|
262 |
+
plt.clf()
|
263 |
+
plt.plot(plot_precision, 'y-', label='f1-score')
|
264 |
+
plt.title('F1')
|
265 |
+
plt.xlabel('Epochs')
|
266 |
+
plt.savefig(dir_current + '/f1-score.png')
|
267 |
+
|
268 |
+
|
269 |
+
if __name__ == "__main__":
|
270 |
+
parser = argparse.ArgumentParser(description="Spoken_language_identification Model training")
|
271 |
+
parser.add_argument("--config_file", type=str, default='./configs/config.yml', help="Config File Path")
|
272 |
+
args = parser.parse_args()
|
273 |
+
kwargs = vars(args)
|
274 |
+
with mirrored_strategy.scope():
|
275 |
+
train(**kwargs)
|
util/__init__.py
ADDED
File without changes
|
util/utils.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 Beijing BluePulse Corp.
|
3 |
+
# Created by Zhang Guanqun on 2020/6/5
|
4 |
+
|
5 |
+
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import os
|
8 |
+
import tensorflow as tf
|
9 |
+
from typing import Union, List
|
10 |
+
import unicodedata
|
11 |
+
|
12 |
+
|
13 |
+
def preprocess_paths(paths: Union[List, str]):
|
14 |
+
if isinstance(paths, list):
|
15 |
+
return [os.path.abspath(os.path.expanduser(path)) for path in paths]
|
16 |
+
return os.path.abspath(os.path.expanduser(paths)) if paths else None
|
17 |
+
|
18 |
+
|
19 |
+
def get_reduced_length(length, reduction_factor):
|
20 |
+
return tf.cast(tf.math.ceil(tf.divide(length, tf.cast(reduction_factor, dtype=length.dtype))), dtype=tf.int32)
|
21 |
+
|
22 |
+
|
23 |
+
def merge_two_last_dims(x):
|
24 |
+
b, _, f, c = shape_list(x)
|
25 |
+
return tf.reshape(x, shape=[b, -1, f * c])
|
26 |
+
|
27 |
+
|
28 |
+
def shape_list(x):
|
29 |
+
"""Deal with dynamic shape in tensorflow cleanly."""
|
30 |
+
static = x.shape.as_list()
|
31 |
+
dynamic = tf.shape(x)
|
32 |
+
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
|
33 |
+
|
34 |
+
|
35 |
+
# draw loss pic
|
36 |
+
def plot_metric(history, metric, pic_file_name):
|
37 |
+
train_metrics = history.history[metric]
|
38 |
+
val_metrics = history.history['val_'+metric]
|
39 |
+
epochs = range(1, len(train_metrics) + 1)
|
40 |
+
plt.plot(epochs, train_metrics, 'bo--')
|
41 |
+
plt.plot(epochs, val_metrics, 'ro-')
|
42 |
+
plt.title('Training and validation '+ metric)
|
43 |
+
plt.xlabel("Epochs")
|
44 |
+
plt.ylabel(metric)
|
45 |
+
plt.legend(["train_"+metric, 'val_'+metric])
|
46 |
+
plt.savefig(pic_file_name)
|
47 |
+
|
48 |
+
|
49 |
+
# against LAS loop decoding
|
50 |
+
def text_no_repeat(s):
|
51 |
+
repeat_times = 0
|
52 |
+
repeat_pattern = ''
|
53 |
+
for i in range(1, len(s) // 2):
|
54 |
+
pos = i
|
55 |
+
if s[0 - 2 * pos:0 - pos] == s[0 - i:]:
|
56 |
+
tmp_repeat_pattern = s[0 - i:]
|
57 |
+
tmp_repeat_times = 1
|
58 |
+
while pos * (tmp_repeat_times + 2) <= len(s) \
|
59 |
+
and s[0 - pos * (tmp_repeat_times + 2):0 - pos * (tmp_repeat_times + 1)] == s[0 - i:]:
|
60 |
+
tmp_repeat_times += 1
|
61 |
+
if tmp_repeat_times * len(tmp_repeat_pattern) > repeat_times * len(repeat_pattern):
|
62 |
+
repeat_times = tmp_repeat_times
|
63 |
+
repeat_pattern = tmp_repeat_pattern
|
64 |
+
# print(repeat_pattern, '*', repeat_times)
|
65 |
+
if len(repeat_pattern) != 1:
|
66 |
+
s = s[:0 - repeat_times * len(repeat_pattern)] if repeat_times > 0 else s
|
67 |
+
# print(s)
|
68 |
+
return s
|
69 |
+
|
70 |
+
# Converts the unicode file to ascii
|
71 |
+
def unicode_to_ascii(s):
|
72 |
+
return ''.join(c for c in unicodedata.normalize('NFD', s)
|
73 |
+
if unicodedata.category(c) != 'Mn')
|
74 |
+
|
75 |
+
def log10(x):
|
76 |
+
numerator = tf.math.log(x)
|
77 |
+
denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
|
78 |
+
return numerator / denominator
|
vocab/__init__.py
ADDED
File without changes
|
vocab/vocab.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
class Vocab:
|
3 |
+
def __init__(self, file_path):
|
4 |
+
self.token_list = []
|
5 |
+
self.load_vocab_from_file(file_path)
|
6 |
+
|
7 |
+
def load_vocab_from_file(self, file_path):
|
8 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
9 |
+
vocabs = f.readlines()
|
10 |
+
for vocab in vocabs:
|
11 |
+
self.token_list.append(vocab.strip('\n'))
|
vocab/vocab.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
chinese
|
2 |
+
english
|
3 |
+
french
|
4 |
+
german
|
5 |
+
indonesian
|
6 |
+
italian
|
7 |
+
japanese
|
8 |
+
korean
|
9 |
+
portuguese
|
10 |
+
russian
|
11 |
+
spanish
|
12 |
+
turkish
|
13 |
+
vietnamese
|
14 |
+
other
|