Spaces:
Runtime error
Runtime error
jonathanjordan21
commited on
Commit
•
c021d8e
1
Parent(s):
48f5453
67809715652a92b22870c50ad30f6ff38e292006aedc75ddbdc828aa856ef68f
Browse files- LICENSE +21 -0
- README.md +71 -12
- assets/cover.png +0 -0
- assets/mfcc_stats.pkl +3 -0
- assets/spk2emb_82.pkl +3 -0
- assets/test_vctk.meta +0 -0
- data_loader.py +228 -0
- demo.ipynb +115 -0
- fast_decoders.py +75 -0
- hparams_autopst.py +62 -0
- hparams_sea.py +21 -0
- main_1.py +41 -0
- main_2.py +41 -0
- model_autopst.py +258 -0
- model_sea.py +242 -0
- onmt_modules/__init__.py +0 -0
- onmt_modules/average_attn.py +111 -0
- onmt_modules/decoder.py +25 -0
- onmt_modules/decoder_transformer.py +358 -0
- onmt_modules/embeddings.py +52 -0
- onmt_modules/encoder.py +58 -0
- onmt_modules/encoder_transformer.py +135 -0
- onmt_modules/misc.py +173 -0
- onmt_modules/multi_headed_attn.py +230 -0
- onmt_modules/position_ffn.py +41 -0
- override_decoder.py +63 -0
- prepare_train_data.py +90 -0
- solver_1.py +146 -0
- solver_2.py +160 -0
- tfcompat/__init__.py +0 -0
- tfcompat/hparam.py +726 -0
- tfcompat/readme.md +8 -0
- utils.py +111 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2019
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,12 +1,71 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Global Prosody Style Transfer Without Text Transcriptions
|
2 |
+
|
3 |
+
This repository provides a PyTorch implementation of [AutoPST](https://arxiv.org/abs/2106.08519), which enables unsupervised global prosody conversion without text transcriptions.
|
4 |
+
|
5 |
+
This is a short video that explains the main concepts of our work. If you find this work useful and use it in your research, please consider citing our paper.
|
6 |
+
|
7 |
+
[![SpeechSplit](./assets/cover.png)](https://youtu.be/wow2DRuJ69c/)
|
8 |
+
|
9 |
+
```
|
10 |
+
@InProceedings{pmlr-v139-qian21b,
|
11 |
+
title = {Global Prosody Style Transfer Without Text Transcriptions},
|
12 |
+
author = {Qian, Kaizhi and Zhang, Yang and Chang, Shiyu and Xiong, Jinjun and Gan, Chuang and Cox, David and Hasegawa-Johnson, Mark},
|
13 |
+
booktitle = {Proceedings of the 38th International Conference on Machine Learning},
|
14 |
+
pages = {8650--8660},
|
15 |
+
year = {2021},
|
16 |
+
editor = {Meila, Marina and Zhang, Tong},
|
17 |
+
volume = {139},
|
18 |
+
series = {Proceedings of Machine Learning Research},
|
19 |
+
month = {18--24 Jul},
|
20 |
+
publisher = {PMLR},
|
21 |
+
url = {http://proceedings.mlr.press/v139/qian21b.html}
|
22 |
+
}
|
23 |
+
|
24 |
+
```
|
25 |
+
|
26 |
+
|
27 |
+
## Audio Demo
|
28 |
+
|
29 |
+
The audio demo for AutoPST can be found [here](https://auspicious3000.github.io/AutoPST-Demo/)
|
30 |
+
|
31 |
+
## Dependencies
|
32 |
+
- Python 3.6
|
33 |
+
- Numpy
|
34 |
+
- Scipy
|
35 |
+
- PyTorch == v1.6.0
|
36 |
+
- librosa
|
37 |
+
- pysptk
|
38 |
+
- soundfile
|
39 |
+
- wavenet_vocoder ```pip install wavenet_vocoder==0.1.1```
|
40 |
+
for more information, please refer to https://github.com/r9y9/wavenet_vocoder
|
41 |
+
|
42 |
+
|
43 |
+
## To Run Demo
|
44 |
+
|
45 |
+
Download [pre-trained models](https://drive.google.com/file/d/1ji3Bk6YGvXkPqFu1hLOAJp_SKw-vHGrp/view?usp=sharing) to ```assets```
|
46 |
+
|
47 |
+
Download the same WaveNet vocoder model as in [AutoVC](https://github.com/auspicious3000/autovc) to ```assets```
|
48 |
+
|
49 |
+
The fast and high-quality hifi-gan v1 (https://github.com/jik876/hifi-gan) pre-trained model is now available [here.](https://drive.google.com/file/d/1n76jHs8k1sDQ3Eh5ajXwdxuY_EZw4N9N/view?usp=sharing)
|
50 |
+
|
51 |
+
Please refer to [AutoVC](https://github.com/auspicious3000/autovc) if you have any problems with the vocoder part, because they share the same vocoder scripts.
|
52 |
+
|
53 |
+
Run ```demo.ipynb```
|
54 |
+
|
55 |
+
|
56 |
+
## To Train
|
57 |
+
|
58 |
+
Download [training data](https://drive.google.com/file/d/1H1dyA80qREKLHybqnYaqBRRsacIdFbnE/view?usp=sharing) to ```assets```.
|
59 |
+
The provided training data is very small for code verification purpose only.
|
60 |
+
Please use the scripts to prepare your own data for training.
|
61 |
+
|
62 |
+
1. Prepare training data: ```python prepare_train_data.py```
|
63 |
+
|
64 |
+
2. Train 1st Stage: ```python main_1.py```
|
65 |
+
|
66 |
+
3. Train 2nd Stage: ```python main_2.py```
|
67 |
+
|
68 |
+
|
69 |
+
## Final Words
|
70 |
+
|
71 |
+
This project is part of an ongoing research. We hope this repo is useful for your research. If you need any help or have any suggestions on improving the framework, please raise an issue and we will do our best to get back to you as soon as possible.
|
assets/cover.png
ADDED
assets/mfcc_stats.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e34f5e00e6eb8591e8dcc3796a56a048c8512245ca552c83069a4c8eb3a57387
|
3 |
+
size 52719
|
assets/spk2emb_82.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ea97c16a4d1d2eca10a0481630daa8590d9d8a1e27a87e53fa4340c17440745a
|
3 |
+
size 32133
|
assets/test_vctk.meta
ADDED
Binary file (200 kB). View file
|
|
data_loader.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pickle
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from numpy.random import uniform
|
7 |
+
from torch.utils import data
|
8 |
+
from torch.utils.data.sampler import Sampler
|
9 |
+
from multiprocessing import Process, Manager
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
class Utterances(data.Dataset):
|
14 |
+
"""Dataset class for the Utterances dataset."""
|
15 |
+
|
16 |
+
def __init__(self, hparams):
|
17 |
+
"""Initialize and preprocess the Utterances dataset."""
|
18 |
+
self.meta_file = hparams.meta_file
|
19 |
+
|
20 |
+
self.feat_dir_1 = hparams.feat_dir_1
|
21 |
+
self.feat_dir_2 = hparams.feat_dir_2
|
22 |
+
self.feat_dir_3 = hparams.feat_dir_3
|
23 |
+
|
24 |
+
self.step = 4
|
25 |
+
self.split = 0
|
26 |
+
|
27 |
+
self.max_len_pad = hparams.max_len_pad
|
28 |
+
|
29 |
+
meta = pickle.load(open(self.meta_file, "rb"))
|
30 |
+
|
31 |
+
manager = Manager()
|
32 |
+
meta = manager.list(meta)
|
33 |
+
dataset = manager.list(len(meta)*[None]) # <-- can be shared between processes.
|
34 |
+
processes = []
|
35 |
+
for i in range(0, len(meta), self.step):
|
36 |
+
p = Process(target=self.load_data,
|
37 |
+
args=(meta[i:i+self.step],dataset,i))
|
38 |
+
p.start()
|
39 |
+
processes.append(p)
|
40 |
+
for p in processes:
|
41 |
+
p.join()
|
42 |
+
|
43 |
+
# very importtant to do dataset = list(dataset)
|
44 |
+
self.train_dataset = list(dataset)
|
45 |
+
self.num_tokens = len(self.train_dataset)
|
46 |
+
|
47 |
+
print('Finished loading the {} Utterances training dataset...'.format(self.num_tokens))
|
48 |
+
|
49 |
+
|
50 |
+
def load_data(self, submeta, dataset, idx_offset):
|
51 |
+
for k, sbmt in enumerate(submeta):
|
52 |
+
uttrs = len(sbmt)*[None]
|
53 |
+
for j, tmp in enumerate(sbmt):
|
54 |
+
if j < 2:
|
55 |
+
# fill in speaker name and embedding
|
56 |
+
uttrs[j] = tmp
|
57 |
+
else:
|
58 |
+
# fill in data
|
59 |
+
sp_tmp = np.load(os.path.join(self.feat_dir_1, tmp))
|
60 |
+
cep_tmp = np.load(os.path.join(self.feat_dir_2, tmp))[:, 0:14]
|
61 |
+
cd_tmp = np.load(os.path.join(self.feat_dir_3, tmp))
|
62 |
+
|
63 |
+
assert len(sp_tmp) == len(cep_tmp) == len(cd_tmp)
|
64 |
+
|
65 |
+
uttrs[j] = ( np.clip(sp_tmp, 0, 1), cep_tmp, cd_tmp )
|
66 |
+
dataset[idx_offset+k] = uttrs
|
67 |
+
|
68 |
+
|
69 |
+
def segment_np(self, cd_long, tau=2):
|
70 |
+
|
71 |
+
cd_norm = np.sqrt((cd_long ** 2).sum(axis=-1, keepdims=True))
|
72 |
+
G = (cd_long @ cd_long.T) / (cd_norm @ cd_norm.T)
|
73 |
+
|
74 |
+
L = G.shape[0]
|
75 |
+
|
76 |
+
num_rep = []
|
77 |
+
num_rep_sync = []
|
78 |
+
|
79 |
+
prev_boundary = 0
|
80 |
+
rate = np.random.uniform(0.8, 1.3)
|
81 |
+
|
82 |
+
for t in range(1, L+1):
|
83 |
+
if t==L:
|
84 |
+
num_rep.append(t - prev_boundary)
|
85 |
+
num_rep_sync.append(t - prev_boundary)
|
86 |
+
prev_boundary = t
|
87 |
+
if t < L:
|
88 |
+
q = np.random.uniform(rate-0.1, rate)
|
89 |
+
tmp = G[prev_boundary, max(prev_boundary-20, 0):min(prev_boundary+20, L)]
|
90 |
+
if q <= 1:
|
91 |
+
epsilon = np.quantile(tmp, q)
|
92 |
+
if np.all(G[prev_boundary, t:min(t+tau, L)] < epsilon):
|
93 |
+
num_rep.append(t - prev_boundary)
|
94 |
+
num_rep_sync.append(t - prev_boundary)
|
95 |
+
prev_boundary = t
|
96 |
+
else:
|
97 |
+
epsilon = np.quantile(tmp, 2-q)
|
98 |
+
if np.all(G[prev_boundary, t:min(t+tau, L)] < epsilon):
|
99 |
+
num_rep.append(t - prev_boundary)
|
100 |
+
else:
|
101 |
+
num_rep.extend([t-prev_boundary-0.5, 0.5])
|
102 |
+
|
103 |
+
num_rep_sync.append(t - prev_boundary)
|
104 |
+
prev_boundary = t
|
105 |
+
|
106 |
+
num_rep = np.array(num_rep)
|
107 |
+
num_rep_sync = np.array(num_rep_sync)
|
108 |
+
|
109 |
+
return num_rep, num_rep_sync
|
110 |
+
|
111 |
+
|
112 |
+
def __getitem__(self, index):
|
113 |
+
"""Return M uttrs for one spkr."""
|
114 |
+
dataset = self.train_dataset
|
115 |
+
|
116 |
+
list_uttrs = dataset[index]
|
117 |
+
|
118 |
+
emb_org = list_uttrs[1]
|
119 |
+
|
120 |
+
uttr = np.random.randint(2, len(list_uttrs))
|
121 |
+
melsp, melcep, cd_real = list_uttrs[uttr]
|
122 |
+
|
123 |
+
num_rep, num_rep_sync = self.segment_np(cd_real)
|
124 |
+
|
125 |
+
return melsp, melcep, cd_real, num_rep, num_rep_sync, len(melsp), len(num_rep), len(num_rep_sync), emb_org
|
126 |
+
|
127 |
+
|
128 |
+
def __len__(self):
|
129 |
+
"""Return the number of spkrs."""
|
130 |
+
return self.num_tokens
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
class MyCollator(object):
|
135 |
+
def __init__(self, hparams):
|
136 |
+
self.max_len_pad = hparams.max_len_pad
|
137 |
+
|
138 |
+
def __call__(self, batch):
|
139 |
+
new_batch = []
|
140 |
+
|
141 |
+
l_short_max = 0
|
142 |
+
l_short_sync_max = 0
|
143 |
+
l_real_max = 0
|
144 |
+
|
145 |
+
for token in batch:
|
146 |
+
sp_real, cep_real, cd_real, rep, rep_sync, l_real, l_short, l_short_sync, emb = token
|
147 |
+
|
148 |
+
if l_short > l_short_max:
|
149 |
+
l_short_max = l_short
|
150 |
+
|
151 |
+
if l_short_sync > l_short_sync_max:
|
152 |
+
l_short_sync_max = l_short_sync
|
153 |
+
|
154 |
+
if l_real > l_real_max:
|
155 |
+
l_real_max = l_real
|
156 |
+
|
157 |
+
sp_real_pad = np.pad(sp_real, ((0,self.max_len_pad-l_real),(0,0)), 'constant')
|
158 |
+
cep_real_pad = np.pad(cep_real, ((0,self.max_len_pad-l_real),(0,0)), 'constant')
|
159 |
+
cd_real_pad = np.pad(cd_real, ((0,self.max_len_pad-l_real),(0,0)), 'constant')
|
160 |
+
|
161 |
+
rep_pad = np.pad(rep, (0,self.max_len_pad-l_short), 'constant')
|
162 |
+
rep_sync_pad = np.pad(rep_sync, (0,self.max_len_pad-l_short_sync), 'constant')
|
163 |
+
|
164 |
+
new_batch.append( (sp_real_pad, cep_real_pad, cd_real_pad, rep_pad, rep_sync_pad, l_real, l_short, l_short_sync, emb) )
|
165 |
+
|
166 |
+
batch = new_batch
|
167 |
+
|
168 |
+
a, b, c, d, e, f, g, h, i = zip(*batch)
|
169 |
+
|
170 |
+
sp_real = torch.from_numpy(np.stack(a, axis=0))[:,:l_real_max+1,:]
|
171 |
+
cep_real = torch.from_numpy(np.stack(b, axis=0))[:,:l_real_max+1,:]
|
172 |
+
cd_real = torch.from_numpy(np.stack(c, axis=0))[:,:l_real_max+1,:]
|
173 |
+
num_rep = torch.from_numpy(np.stack(d, axis=0))[:,:l_short_max+1]
|
174 |
+
num_rep_sync = torch.from_numpy(np.stack(e, axis=0))[:,:l_short_sync_max+1]
|
175 |
+
|
176 |
+
len_real = torch.from_numpy(np.stack(f, axis=0))
|
177 |
+
len_short = torch.from_numpy(np.stack(g, axis=0))
|
178 |
+
len_short_sync = torch.from_numpy(np.stack(h, axis=0))
|
179 |
+
|
180 |
+
spk_emb = torch.from_numpy(np.stack(i, axis=0))
|
181 |
+
|
182 |
+
return sp_real, cep_real, cd_real, num_rep, num_rep_sync, len_real, len_short, len_short_sync, spk_emb
|
183 |
+
|
184 |
+
|
185 |
+
|
186 |
+
class MultiSampler(Sampler):
|
187 |
+
"""Samples elements more than once in a single pass through the data.
|
188 |
+
"""
|
189 |
+
def __init__(self, num_samples, n_repeats, shuffle=False):
|
190 |
+
self.num_samples = num_samples
|
191 |
+
self.n_repeats = n_repeats
|
192 |
+
self.shuffle = shuffle
|
193 |
+
|
194 |
+
def gen_sample_array(self):
|
195 |
+
self.sample_idx_array = torch.arange(self.num_samples, dtype=torch.int64).repeat(self.n_repeats)
|
196 |
+
if self.shuffle:
|
197 |
+
self.sample_idx_array = self.sample_idx_array[torch.randperm(len(self.sample_idx_array))]
|
198 |
+
return self.sample_idx_array
|
199 |
+
|
200 |
+
def __iter__(self):
|
201 |
+
return iter(self.gen_sample_array())
|
202 |
+
|
203 |
+
def __len__(self):
|
204 |
+
return len(self.sample_idx_array)
|
205 |
+
|
206 |
+
|
207 |
+
|
208 |
+
def worker_init_fn(x):
|
209 |
+
return np.random.seed((torch.initial_seed()) % (2**32))
|
210 |
+
|
211 |
+
def get_loader(hparams):
|
212 |
+
"""Build and return a data loader."""
|
213 |
+
|
214 |
+
dataset = Utterances(hparams)
|
215 |
+
|
216 |
+
my_collator = MyCollator(hparams)
|
217 |
+
|
218 |
+
sampler = MultiSampler(len(dataset), hparams.samplier, shuffle=hparams.shuffle)
|
219 |
+
|
220 |
+
data_loader = data.DataLoader(dataset=dataset,
|
221 |
+
batch_size=hparams.batch_size,
|
222 |
+
sampler=sampler,
|
223 |
+
num_workers=hparams.num_workers,
|
224 |
+
drop_last=True,
|
225 |
+
pin_memory=False,
|
226 |
+
worker_init_fn=worker_init_fn,
|
227 |
+
collate_fn=my_collator)
|
228 |
+
return data_loader
|
demo.ipynb
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import os\n",
|
10 |
+
"import pickle\n",
|
11 |
+
"import numpy as np\n",
|
12 |
+
"import torch\n",
|
13 |
+
"import torch.nn.functional as F\n",
|
14 |
+
"from collections import OrderedDict\n",
|
15 |
+
"from onmt_modules.misc import sequence_mask\n",
|
16 |
+
"from model_autopst import Generator_2 as Predictor\n",
|
17 |
+
"from hparams_autopst import hparams\n",
|
18 |
+
"\n",
|
19 |
+
"device = 'cuda:0'\n",
|
20 |
+
"\n",
|
21 |
+
"P = Predictor(hparams).eval().to(device)\n",
|
22 |
+
"\n",
|
23 |
+
"checkpoint = torch.load('./assets/580000-P.ckpt', map_location=lambda storage, loc: storage) \n",
|
24 |
+
"P.load_state_dict(checkpoint['model'], strict=True)\n",
|
25 |
+
"print('Loaded predictor .....................................................')\n",
|
26 |
+
"\n",
|
27 |
+
"dict_test = pickle.load(open('./assets/test_vctk.meta', 'rb'))"
|
28 |
+
]
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"cell_type": "code",
|
32 |
+
"execution_count": null,
|
33 |
+
"metadata": {},
|
34 |
+
"outputs": [],
|
35 |
+
"source": [
|
36 |
+
"spect_vc = OrderedDict()\n",
|
37 |
+
"\n",
|
38 |
+
"uttrs = [('p231', 'p270', '001'),\n",
|
39 |
+
" ('p270', 'p231', '001'),\n",
|
40 |
+
" ('p231', 'p245', '003001'),\n",
|
41 |
+
" ('p245', 'p231', '003001'),\n",
|
42 |
+
" ('p239', 'p270', '024002'),\n",
|
43 |
+
" ('p270', 'p239', '024002')]\n",
|
44 |
+
"\n",
|
45 |
+
"\n",
|
46 |
+
"for uttr in uttrs:\n",
|
47 |
+
" \n",
|
48 |
+
" cep_real, spk_emb = dict_test[uttr[0]][uttr[2]]\n",
|
49 |
+
" cep_real_A = torch.from_numpy(cep_real).unsqueeze(0).to(device)\n",
|
50 |
+
" len_real_A = torch.tensor(cep_real_A.size(1)).unsqueeze(0).to(device)\n",
|
51 |
+
" real_mask_A = sequence_mask(len_real_A, cep_real_A.size(1)).float()\n",
|
52 |
+
" \n",
|
53 |
+
" _, spk_emb = dict_test[uttr[1]][uttr[2]]\n",
|
54 |
+
" spk_emb_B = torch.from_numpy(spk_emb).unsqueeze(0).to(device)\n",
|
55 |
+
" \n",
|
56 |
+
" with torch.no_grad():\n",
|
57 |
+
" spect_output, len_spect = P.infer_onmt(cep_real_A.transpose(2,1)[:,:14,:],\n",
|
58 |
+
" real_mask_A,\n",
|
59 |
+
" len_real_A,\n",
|
60 |
+
" spk_emb_B)\n",
|
61 |
+
" \n",
|
62 |
+
" uttr_tgt = spect_output[:len_spect[0],0,:].cpu().numpy()\n",
|
63 |
+
" \n",
|
64 |
+
" spect_vc[f'{uttr[0]}_{uttr[1]}_{uttr[2]}'] = uttr_tgt"
|
65 |
+
]
|
66 |
+
},
|
67 |
+
{
|
68 |
+
"cell_type": "code",
|
69 |
+
"execution_count": null,
|
70 |
+
"metadata": {},
|
71 |
+
"outputs": [],
|
72 |
+
"source": [
|
73 |
+
"# spectrogram to waveform\n",
|
74 |
+
"# Feel free to use other vocoders\n",
|
75 |
+
"# This cell requires some preparation to work, please see the corresponding part in AutoVC\n",
|
76 |
+
"import torch\n",
|
77 |
+
"import librosa\n",
|
78 |
+
"import pickle\n",
|
79 |
+
"import os\n",
|
80 |
+
"from synthesis import build_model\n",
|
81 |
+
"from synthesis import wavegen\n",
|
82 |
+
"\n",
|
83 |
+
"model = build_model().to(device)\n",
|
84 |
+
"checkpoint = torch.load(\"./assets/checkpoint_step001000000_ema.pth\")\n",
|
85 |
+
"model.load_state_dict(checkpoint[\"state_dict\"])\n",
|
86 |
+
"\n",
|
87 |
+
"for name, sp in spect_vc.items():\n",
|
88 |
+
" print(name)\n",
|
89 |
+
" waveform = wavegen(model, c=sp) \n",
|
90 |
+
" librosa.output.write_wav('./assets/'+name+'.wav', waveform, sr=16000)"
|
91 |
+
]
|
92 |
+
}
|
93 |
+
],
|
94 |
+
"metadata": {
|
95 |
+
"kernelspec": {
|
96 |
+
"display_name": "Python 3",
|
97 |
+
"language": "python",
|
98 |
+
"name": "python3"
|
99 |
+
},
|
100 |
+
"language_info": {
|
101 |
+
"codemirror_mode": {
|
102 |
+
"name": "ipython",
|
103 |
+
"version": 3
|
104 |
+
},
|
105 |
+
"file_extension": ".py",
|
106 |
+
"mimetype": "text/x-python",
|
107 |
+
"name": "python",
|
108 |
+
"nbconvert_exporter": "python",
|
109 |
+
"pygments_lexer": "ipython3",
|
110 |
+
"version": "3.7.5"
|
111 |
+
}
|
112 |
+
},
|
113 |
+
"nbformat": 4,
|
114 |
+
"nbformat_minor": 4
|
115 |
+
}
|
fast_decoders.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from onmt_modules.misc import sequence_mask
|
6 |
+
|
7 |
+
|
8 |
+
class DecodeFunc_Sp(object):
|
9 |
+
"""
|
10 |
+
Decoding functions
|
11 |
+
"""
|
12 |
+
def __init__(self, hparams, type_out):
|
13 |
+
|
14 |
+
if type_out == 'Sp':
|
15 |
+
self.dim_freq = hparams.dim_freq
|
16 |
+
self.max_decoder_steps = hparams.dec_steps_sp
|
17 |
+
elif type_out == 'Tx':
|
18 |
+
self.dim_freq = hparams.dim_code
|
19 |
+
self.max_decoder_steps = hparams.dec_steps_tx
|
20 |
+
else:
|
21 |
+
raise ValueError
|
22 |
+
|
23 |
+
self.gate_threshold = hparams.gate_threshold
|
24 |
+
self.type_out = type_out
|
25 |
+
|
26 |
+
def __call__(self, tgt, memory_bank, memory_lengths, decoder, postnet):
|
27 |
+
|
28 |
+
dec_outs, attns = decoder(tgt, memory_bank, step=None,
|
29 |
+
memory_lengths=memory_lengths)
|
30 |
+
spect_gate = postnet(dec_outs)
|
31 |
+
spect, gate = spect_gate[:, :, 1:], spect_gate[:, :, :1]
|
32 |
+
|
33 |
+
return spect, gate
|
34 |
+
|
35 |
+
|
36 |
+
def infer(self, tgt_real, memory_bank, memory_lengths, decoder, postnet):
|
37 |
+
B = memory_bank.size(1)
|
38 |
+
device = memory_bank.device
|
39 |
+
|
40 |
+
spect_outputs = torch.zeros((self.max_decoder_steps, B, self.dim_freq),
|
41 |
+
dtype=torch.float, device=device)
|
42 |
+
gate_outputs = torch.zeros((self.max_decoder_steps, B, 1),
|
43 |
+
dtype=torch.float, device=device)
|
44 |
+
tgt_words = torch.zeros([B, 1],
|
45 |
+
dtype=torch.float, device=device)
|
46 |
+
|
47 |
+
current_pred = torch.zeros([1, B, self.dim_freq],
|
48 |
+
dtype=torch.float, device=device)
|
49 |
+
|
50 |
+
for t in range(self.max_decoder_steps):
|
51 |
+
|
52 |
+
dec_outs, _ = decoder(current_pred,
|
53 |
+
memory_bank, t,
|
54 |
+
memory_lengths=memory_lengths,
|
55 |
+
tgt_words=tgt_words)
|
56 |
+
spect_gate = postnet(dec_outs)
|
57 |
+
|
58 |
+
spect, gate = spect_gate[:, :, 1:], spect_gate[:, :, :1]
|
59 |
+
spect_outputs[t:t+1] = spect
|
60 |
+
gate_outputs[t:t+1] = gate
|
61 |
+
|
62 |
+
stop = (torch.sigmoid(gate) - self.gate_threshold + 0.5).round()
|
63 |
+
current_pred = spect.data
|
64 |
+
tgt_words = stop.squeeze(-1).t()
|
65 |
+
|
66 |
+
if t == self.max_decoder_steps - 1:
|
67 |
+
print(f"Warning! {self.type_out} reached max decoder steps")
|
68 |
+
|
69 |
+
if (stop == 1).all():
|
70 |
+
break
|
71 |
+
|
72 |
+
stop_quant = (torch.sigmoid(gate_outputs.data) - self.gate_threshold + 0.5).round().squeeze(-1)
|
73 |
+
len_spect = (stop_quant.cumsum(dim=0)==0).sum(dim=0)
|
74 |
+
|
75 |
+
return spect_outputs, len_spect, gate_outputs
|
hparams_autopst.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tfcompat.hparam import HParams
|
2 |
+
|
3 |
+
# NOTE: If you want full control for model architecture. please take a look
|
4 |
+
# at the code and change whatever you want. Some hyper parameters are hardcoded.
|
5 |
+
|
6 |
+
# Default hyperparameters:
|
7 |
+
hparams = HParams(
|
8 |
+
|
9 |
+
# sea params
|
10 |
+
dim_neck_sea = 4,
|
11 |
+
dim_freq_sea = 14,
|
12 |
+
dim_enc_sea = 512,
|
13 |
+
|
14 |
+
# autopst params
|
15 |
+
dim_freq = 80,
|
16 |
+
dim_code = 4,
|
17 |
+
dim_spk = 82,
|
18 |
+
dim_sty = 128,
|
19 |
+
gate_threshold = 0.48,
|
20 |
+
dec_steps_tx = 640,
|
21 |
+
dec_steps_sp = 806,
|
22 |
+
chs_grp = 16,
|
23 |
+
|
24 |
+
# onmt params
|
25 |
+
enc_layers = 4,
|
26 |
+
enc_rnn_size = 256,
|
27 |
+
dec_layers = 4,
|
28 |
+
dec_rnn_size = 256,
|
29 |
+
transformer_ff = 1024,
|
30 |
+
heads = 8,
|
31 |
+
dropout = 0.1,
|
32 |
+
attention_dropout = 0.1,
|
33 |
+
max_relative_positions = 0,
|
34 |
+
copy_attn = False,
|
35 |
+
self_attn_type = "scaled-dot",
|
36 |
+
aan_useffn = False,
|
37 |
+
full_context_alignment = False,
|
38 |
+
alignment_layer = 0,
|
39 |
+
alignment_heads = 0,
|
40 |
+
|
41 |
+
# pretrained model
|
42 |
+
pretrained_path = './assets/xxx.ckpt',
|
43 |
+
|
44 |
+
# data loader
|
45 |
+
meta_file = './assets/train_vctk.meta',
|
46 |
+
feat_dir_1 = './assets/vctk16-train-sp-mel',
|
47 |
+
feat_dir_2 = './assets/vctk16-train-cep-mel',
|
48 |
+
feat_dir_3 = './assets/vctk16-train-teacher',
|
49 |
+
batch_size = 4,
|
50 |
+
shuffle = True,
|
51 |
+
num_workers = 0,
|
52 |
+
samplier = 2,
|
53 |
+
max_len_pad = 2048,
|
54 |
+
sampling_params = (0.8, 1.3, 0.1),
|
55 |
+
|
56 |
+
)
|
57 |
+
|
58 |
+
|
59 |
+
def hparams_debug_string():
|
60 |
+
values = hparams.values()
|
61 |
+
hp = [' %s: %s' % (name, values[name]) for name in values]
|
62 |
+
return 'Hyperparameters:\n' + '\n'.join(hp)
|
hparams_sea.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tfcompat.hparam import HParams
|
2 |
+
|
3 |
+
# NOTE: If you want full control for model architecture. please take a look
|
4 |
+
# at the code and change whatever you want. Some hyper parameters are hardcoded.
|
5 |
+
|
6 |
+
# Default hyperparameters:
|
7 |
+
hparams = HParams(
|
8 |
+
dim_neck_sea = 8,
|
9 |
+
dim_freq_sea = 20,
|
10 |
+
dim_spk = 82,
|
11 |
+
dim_enc_sea = 512,
|
12 |
+
chs_grp = 16,
|
13 |
+
dim_freq_sp = 80,
|
14 |
+
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
def hparams_debug_string():
|
19 |
+
values = hparams.values()
|
20 |
+
hp = [' %s: %s' % (name, values[name]) for name in values]
|
21 |
+
return 'Hyperparameters:\n' + '\n'.join(hp)
|
main_1.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from solver_1 import Solver
|
6 |
+
from data_loader import get_loader
|
7 |
+
from hparams_autopst import hparams, hparams_debug_string
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
def str2bool(v):
|
12 |
+
return v.lower() in ('true')
|
13 |
+
|
14 |
+
def main(config):
|
15 |
+
|
16 |
+
# Data loader
|
17 |
+
data_loader = get_loader(hparams)
|
18 |
+
|
19 |
+
# Solver for training
|
20 |
+
solver = Solver(data_loader, config, hparams)
|
21 |
+
|
22 |
+
solver.train()
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
if __name__ == '__main__':
|
27 |
+
parser = argparse.ArgumentParser()
|
28 |
+
|
29 |
+
# Training configuration.
|
30 |
+
parser.add_argument('--num_iters', type=int, default=1000000)
|
31 |
+
|
32 |
+
# Miscellaneous.
|
33 |
+
parser.add_argument('--device_id', type=int, default=0)
|
34 |
+
|
35 |
+
# Step size.
|
36 |
+
parser.add_argument('--log_step', type=int, default=10)
|
37 |
+
|
38 |
+
config = parser.parse_args()
|
39 |
+
print(config)
|
40 |
+
print(hparams_debug_string())
|
41 |
+
main(config)
|
main_2.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from solver_2 import Solver
|
6 |
+
from data_loader import get_loader
|
7 |
+
from hparams_autopst import hparams, hparams_debug_string
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
def str2bool(v):
|
12 |
+
return v.lower() in ('true')
|
13 |
+
|
14 |
+
def main(config):
|
15 |
+
|
16 |
+
# Data loader
|
17 |
+
data_loader = get_loader(hparams)
|
18 |
+
|
19 |
+
# Solver for training
|
20 |
+
solver = Solver(data_loader, config, hparams)
|
21 |
+
|
22 |
+
solver.train()
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
if __name__ == '__main__':
|
27 |
+
parser = argparse.ArgumentParser()
|
28 |
+
|
29 |
+
# Training configuration.
|
30 |
+
parser.add_argument('--num_iters', type=int, default=1000000)
|
31 |
+
|
32 |
+
# Miscellaneous.
|
33 |
+
parser.add_argument('--device_id', type=int, default=0)
|
34 |
+
|
35 |
+
# Step size.
|
36 |
+
parser.add_argument('--log_step', type=int, default=10)
|
37 |
+
|
38 |
+
config = parser.parse_args()
|
39 |
+
print(config)
|
40 |
+
print(hparams_debug_string())
|
41 |
+
main(config)
|
model_autopst.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from utils import filter_bank_mean
|
7 |
+
|
8 |
+
from fast_decoders import DecodeFunc_Sp
|
9 |
+
|
10 |
+
from model_sea import Encoder_2 as Encoder_Code_2
|
11 |
+
|
12 |
+
from override_decoder import OnmtDecoder_1 as OnmtDecoder
|
13 |
+
|
14 |
+
from onmt_modules.misc import sequence_mask
|
15 |
+
from onmt_modules.embeddings import PositionalEncoding
|
16 |
+
from onmt_modules.encoder_transformer import TransformerEncoder as OnmtEncoder
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
class Prenet(nn.Module):
|
21 |
+
def __init__(self, dim_input, dim_output, dropout=0.1):
|
22 |
+
super().__init__()
|
23 |
+
|
24 |
+
mlp = nn.Linear(dim_input, dim_output, bias=True)
|
25 |
+
pe = PositionalEncoding(dropout, dim_output, 1600)
|
26 |
+
|
27 |
+
self.make_prenet = nn.Sequential()
|
28 |
+
self.make_prenet.add_module('mlp', mlp)
|
29 |
+
self.make_prenet.add_module('pe', pe)
|
30 |
+
|
31 |
+
self.word_padding_idx = 1
|
32 |
+
|
33 |
+
def forward(self, source, step=None):
|
34 |
+
|
35 |
+
for i, module in enumerate(self.make_prenet._modules.values()):
|
36 |
+
if i == len(self.make_prenet._modules.values()) - 1:
|
37 |
+
source = module(source, step=step)
|
38 |
+
else:
|
39 |
+
source = module(source)
|
40 |
+
|
41 |
+
return source
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
class Decoder_Sp(nn.Module):
|
46 |
+
"""
|
47 |
+
Speech Decoder
|
48 |
+
"""
|
49 |
+
def __init__(self, hparams):
|
50 |
+
super().__init__()
|
51 |
+
|
52 |
+
self.dim_freq = hparams.dim_freq
|
53 |
+
self.max_decoder_steps = hparams.dec_steps_sp
|
54 |
+
self.gate_threshold = hparams.gate_threshold
|
55 |
+
|
56 |
+
prenet = Prenet(hparams.dim_freq, hparams.dec_rnn_size)
|
57 |
+
self.decoder = OnmtDecoder.from_opt(hparams, prenet)
|
58 |
+
|
59 |
+
self.postnet = nn.Linear(hparams.dec_rnn_size,
|
60 |
+
hparams.dim_freq+1, bias=True)
|
61 |
+
|
62 |
+
def forward(self, tgt, tgt_lengths, memory_bank, memory_lengths):
|
63 |
+
|
64 |
+
dec_outs, attns = self.decoder(tgt, memory_bank, step=None,
|
65 |
+
memory_lengths=memory_lengths,
|
66 |
+
tgt_lengths=tgt_lengths)
|
67 |
+
spect_gate = self.postnet(dec_outs)
|
68 |
+
spect, gate = spect_gate[:, :, 1:], spect_gate[:, :, :1]
|
69 |
+
|
70 |
+
return spect, gate
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
class Encoder_Tx_Spk(nn.Module):
|
75 |
+
"""
|
76 |
+
Text Encoder
|
77 |
+
"""
|
78 |
+
def __init__(self, hparams):
|
79 |
+
super().__init__()
|
80 |
+
|
81 |
+
prenet = Prenet(hparams.dim_code+hparams.dim_spk,
|
82 |
+
hparams.enc_rnn_size)
|
83 |
+
self.encoder = OnmtEncoder.from_opt(hparams, prenet)
|
84 |
+
|
85 |
+
def forward(self, src, src_lengths, spk_emb):
|
86 |
+
|
87 |
+
spk_emb = spk_emb.unsqueeze(0).expand(src.size(0),-1,-1)
|
88 |
+
src_spk = torch.cat((src, spk_emb), dim=-1)
|
89 |
+
enc_states, memory_bank, src_lengths = self.encoder(src_spk, src_lengths)
|
90 |
+
|
91 |
+
return enc_states, memory_bank, src_lengths
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
class Decoder_Tx(nn.Module):
|
96 |
+
"""
|
97 |
+
Text Decoder with stop
|
98 |
+
and num_rep prediction
|
99 |
+
"""
|
100 |
+
def __init__(self, hparams):
|
101 |
+
super().__init__()
|
102 |
+
|
103 |
+
self.dim_code = hparams.dim_code
|
104 |
+
self.max_decoder_steps = hparams.dec_steps_tx
|
105 |
+
self.gate_threshold = hparams.gate_threshold
|
106 |
+
self.dim_rep = hparams.dim_rep
|
107 |
+
|
108 |
+
prenet = Prenet(hparams.dim_code, hparams.dec_rnn_size)
|
109 |
+
self.decoder = OnmtDecoder.from_opt(hparams, prenet)
|
110 |
+
|
111 |
+
self.postnet_1 = nn.Linear(hparams.dec_rnn_size,
|
112 |
+
hparams.dim_code+1, bias=True)
|
113 |
+
|
114 |
+
self.postnet_2 = nn.Linear(hparams.dec_rnn_size,
|
115 |
+
self.dim_rep, bias=True)
|
116 |
+
|
117 |
+
def forward(self, tgt, tgt_lengths, memory_bank, memory_lengths):
|
118 |
+
|
119 |
+
dec_outs, attns = self.decoder(tgt, memory_bank, step=None,
|
120 |
+
memory_lengths=memory_lengths,
|
121 |
+
tgt_lengths=tgt_lengths)
|
122 |
+
gate_text = self.postnet_1(dec_outs)
|
123 |
+
rep = self.postnet_2(dec_outs)
|
124 |
+
gate, text = gate_text[:, :, :1], gate_text[:, :, 1:]
|
125 |
+
|
126 |
+
return text, gate, rep
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
class Generator_1(nn.Module):
|
131 |
+
'''
|
132 |
+
sync stage 1
|
133 |
+
'''
|
134 |
+
def __init__(self, hparams):
|
135 |
+
super().__init__()
|
136 |
+
|
137 |
+
self.encoder_cd = Encoder_Code_2(hparams)
|
138 |
+
self.encoder_tx = Encoder_Tx_Spk(hparams)
|
139 |
+
self.decoder_sp = Decoder_Sp(hparams)
|
140 |
+
self.encoder_spk = nn.Linear(hparams.dim_spk,
|
141 |
+
hparams.enc_rnn_size, bias=True)
|
142 |
+
self.fast_dec_sp = DecodeFunc_Sp(hparams, 'Sp')
|
143 |
+
|
144 |
+
|
145 |
+
def pad_sequences_rnn(self, cd_short, num_rep, len_long):
|
146 |
+
B, L, C = cd_short.size()
|
147 |
+
out_tensor = torch.zeros((B, len_long.max(), C), device=cd_short.device)
|
148 |
+
'''
|
149 |
+
len_long = len_spect + 1
|
150 |
+
'''
|
151 |
+
for i in range(B):
|
152 |
+
code_sync = cd_short[i].repeat_interleave(num_rep[i], dim=0)
|
153 |
+
out_tensor[i, :len_long[i]-1, :] = code_sync
|
154 |
+
|
155 |
+
return out_tensor
|
156 |
+
|
157 |
+
|
158 |
+
def forward(self, cep_in, mask_long, codes_mask, num_rep, len_short,
|
159 |
+
tgt_spect, len_spect,
|
160 |
+
spk_emb):
|
161 |
+
|
162 |
+
cd_long = self.encoder_cd(cep_in, mask_long)
|
163 |
+
fb = filter_bank_mean(num_rep, codes_mask, cd_long.size(1))
|
164 |
+
|
165 |
+
cd_short = torch.bmm(fb.detach(), cd_long)
|
166 |
+
|
167 |
+
cd_short_sync = self.pad_sequences_rnn(cd_short, num_rep, len_spect)
|
168 |
+
|
169 |
+
spk_emb_1 = self.encoder_spk(spk_emb)
|
170 |
+
|
171 |
+
# text to speech
|
172 |
+
_, memory_tx, _ = self.encoder_tx(cd_short_sync.transpose(1,0), len_spect,
|
173 |
+
spk_emb)
|
174 |
+
memory_tx_spk = torch.cat((spk_emb_1.unsqueeze(0), memory_tx), dim=0)
|
175 |
+
self.decoder_sp.decoder.init_state(memory_tx_spk, None, None)
|
176 |
+
spect_out, gate_sp_out \
|
177 |
+
= self.decoder_sp(tgt_spect, len_spect, memory_tx_spk, len_spect+1)
|
178 |
+
|
179 |
+
return spect_out, gate_sp_out
|
180 |
+
|
181 |
+
|
182 |
+
def infer_onmt(self, cep_in, mask_long,
|
183 |
+
len_spect,
|
184 |
+
spk_emb):
|
185 |
+
|
186 |
+
cd_long = self.encoder_cd(cep_in, mask_long)
|
187 |
+
|
188 |
+
spk_emb_1 = self.encoder_spk(spk_emb)
|
189 |
+
|
190 |
+
# text to speech
|
191 |
+
_, memory_tx, _ = self.encoder_tx(cd_long.transpose(1,0), len_spect,
|
192 |
+
spk_emb)
|
193 |
+
memory_tx_spk = torch.cat((spk_emb_1.unsqueeze(0), memory_tx), dim=0)
|
194 |
+
self.decoder_sp.decoder.init_state(memory_tx_spk, None, None)
|
195 |
+
spect_output, len_spect_out, stop_sp_output \
|
196 |
+
= self.fast_dec_sp.infer(None, memory_tx_spk, len_spect+1,
|
197 |
+
self.decoder_sp.decoder,
|
198 |
+
self.decoder_sp.postnet)
|
199 |
+
|
200 |
+
return spect_output, len_spect_out
|
201 |
+
|
202 |
+
|
203 |
+
|
204 |
+
class Generator_2(nn.Module):
|
205 |
+
'''
|
206 |
+
async stage 2
|
207 |
+
'''
|
208 |
+
def __init__(self, hparams):
|
209 |
+
super().__init__()
|
210 |
+
|
211 |
+
self.encoder_cd = Encoder_Code_2(hparams)
|
212 |
+
self.encoder_tx = Encoder_Tx_Spk(hparams)
|
213 |
+
self.decoder_sp = Decoder_Sp(hparams)
|
214 |
+
self.encoder_spk = nn.Linear(hparams.dim_spk,
|
215 |
+
hparams.enc_rnn_size, bias=True)
|
216 |
+
self.fast_dec_sp = DecodeFunc_Sp(hparams, 'Sp')
|
217 |
+
|
218 |
+
|
219 |
+
def forward(self, cep_in, mask_long, codes_mask, num_rep, len_short,
|
220 |
+
tgt_spect, len_spect,
|
221 |
+
spk_emb):
|
222 |
+
|
223 |
+
cd_long = self.encoder_cd(cep_in, mask_long)
|
224 |
+
fb = filter_bank_mean(num_rep, codes_mask, cd_long.size(1))
|
225 |
+
|
226 |
+
cd_short = torch.bmm(fb.detach(), cd_long.detach())
|
227 |
+
|
228 |
+
spk_emb_1 = self.encoder_spk(spk_emb)
|
229 |
+
|
230 |
+
# text to speech
|
231 |
+
_, memory_tx, _ = self.encoder_tx(cd_short.transpose(1,0), len_short,
|
232 |
+
spk_emb)
|
233 |
+
memory_tx_spk = torch.cat((spk_emb_1.unsqueeze(0), memory_tx), dim=0)
|
234 |
+
self.decoder_sp.decoder.init_state(memory_tx_spk, None, None)
|
235 |
+
spect_out, gate_sp_out \
|
236 |
+
= self.decoder_sp(tgt_spect, len_spect, memory_tx_spk, len_short+1)
|
237 |
+
|
238 |
+
return spect_out, gate_sp_out
|
239 |
+
|
240 |
+
|
241 |
+
def infer_onmt(self, cep_in, mask_long, len_spect,
|
242 |
+
spk_emb):
|
243 |
+
|
244 |
+
cd_long = self.encoder_cd(cep_in, mask_long)
|
245 |
+
|
246 |
+
spk_emb_1 = self.encoder_spk(spk_emb)
|
247 |
+
|
248 |
+
# text to speech
|
249 |
+
_, memory_tx, _ = self.encoder_tx(cd_long.transpose(1,0), len_spect,
|
250 |
+
spk_emb)
|
251 |
+
memory_tx_spk = torch.cat((spk_emb_1.unsqueeze(0), memory_tx), dim=0)
|
252 |
+
self.decoder_sp.decoder.init_state(memory_tx_spk, None, None)
|
253 |
+
spect_output, len_spect_out, stop_sp_output \
|
254 |
+
= self.fast_dec_sp.infer(None, memory_tx_spk, len_spect+1,
|
255 |
+
self.decoder_sp.decoder,
|
256 |
+
self.decoder_sp.postnet)
|
257 |
+
|
258 |
+
return spect_output, len_spect_out
|
model_sea.py
ADDED
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from utils import ConvNorm, LinearNorm
|
7 |
+
from torch.nn.parameter import Parameter
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
class GroupNorm_Mask(nn.Module):
|
12 |
+
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
|
13 |
+
super().__init__()
|
14 |
+
|
15 |
+
self.num_groups = num_groups
|
16 |
+
self.num_channels = num_channels
|
17 |
+
self.eps = eps
|
18 |
+
self.affine = affine
|
19 |
+
if self.affine:
|
20 |
+
self.weight = Parameter(torch.Tensor(num_channels))
|
21 |
+
self.bias = Parameter(torch.Tensor(num_channels))
|
22 |
+
else:
|
23 |
+
self.register_parameter('weight', None)
|
24 |
+
self.register_parameter('bias', None)
|
25 |
+
self.reset_parameters()
|
26 |
+
|
27 |
+
def reset_parameters(self):
|
28 |
+
if self.affine:
|
29 |
+
nn.init.ones_(self.weight)
|
30 |
+
nn.init.zeros_(self.bias)
|
31 |
+
|
32 |
+
def forward(self, x, mask):
|
33 |
+
B, C, L = x.size()
|
34 |
+
assert C % self.num_groups == 0
|
35 |
+
|
36 |
+
x = x.view(B, self.num_groups, C//self.num_groups, L)
|
37 |
+
mask = mask.view(B, 1, 1, L)
|
38 |
+
x = x * mask
|
39 |
+
|
40 |
+
mean = x.mean(dim=2, keepdim=True).sum(dim=3, keepdim=True) / mask.sum(dim=3, keepdim=True)
|
41 |
+
var = (((x - mean)**2)*mask).mean(dim=2, keepdim=True).sum(dim=3, keepdim=True) / mask.sum(dim=3, keepdim=True)
|
42 |
+
|
43 |
+
x = (x - mean) / (var + self.eps).sqrt()
|
44 |
+
|
45 |
+
x = x.view(B, C, L)
|
46 |
+
|
47 |
+
return x * self.weight.view(1,-1,1) + self.bias.view(1,-1,1)
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
class M43_Sequential(nn.Sequential):
|
52 |
+
def forward(self, inputs, mask):
|
53 |
+
inputs = self._modules['0'](inputs)
|
54 |
+
inputs = self._modules['1'](inputs, mask)
|
55 |
+
return inputs
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
class Encoder(nn.Module):
|
60 |
+
"""Encoder module:
|
61 |
+
"""
|
62 |
+
def __init__(self, hparams):
|
63 |
+
super(Encoder, self).__init__()
|
64 |
+
|
65 |
+
self.dim_freq = hparams.dim_freq_sea
|
66 |
+
self.dim_enc = hparams.dim_enc_sea
|
67 |
+
self.chs_grp = hparams.chs_grp
|
68 |
+
self.dim_neck = hparams.dim_neck_sea
|
69 |
+
|
70 |
+
convolutions = []
|
71 |
+
for i in range(5):
|
72 |
+
conv_layer = M43_Sequential(
|
73 |
+
ConvNorm(self.dim_freq if i==0 else self.dim_enc,
|
74 |
+
self.dim_enc,
|
75 |
+
kernel_size=1, stride=1,
|
76 |
+
padding=0,
|
77 |
+
dilation=1, w_init_gain='relu'),
|
78 |
+
GroupNorm_Mask(self.dim_enc//self.chs_grp, self.dim_enc))
|
79 |
+
convolutions.append(conv_layer)
|
80 |
+
|
81 |
+
conv_layer = M43_Sequential(
|
82 |
+
ConvNorm(self.dim_enc,
|
83 |
+
128,
|
84 |
+
kernel_size=1, stride=1,
|
85 |
+
padding=0,
|
86 |
+
dilation=1, w_init_gain='relu'),
|
87 |
+
GroupNorm_Mask(128//self.chs_grp, 128))
|
88 |
+
convolutions.append(conv_layer)
|
89 |
+
|
90 |
+
conv_layer = M43_Sequential(
|
91 |
+
ConvNorm(128,
|
92 |
+
32,
|
93 |
+
kernel_size=1, stride=1,
|
94 |
+
padding=0,
|
95 |
+
dilation=1, w_init_gain='relu'),
|
96 |
+
GroupNorm_Mask(32//self.chs_grp, 32))
|
97 |
+
convolutions.append(conv_layer)
|
98 |
+
|
99 |
+
conv_layer = M43_Sequential(
|
100 |
+
ConvNorm(32,
|
101 |
+
self.dim_neck,
|
102 |
+
kernel_size=1, stride=1,
|
103 |
+
padding=0,
|
104 |
+
dilation=1, w_init_gain='relu'),
|
105 |
+
GroupNorm_Mask(1, self.dim_neck))
|
106 |
+
convolutions.append(conv_layer)
|
107 |
+
|
108 |
+
self.convolutions = nn.ModuleList(convolutions)
|
109 |
+
|
110 |
+
|
111 |
+
def forward(self, x, mask):
|
112 |
+
|
113 |
+
for conv in self.convolutions:
|
114 |
+
x = F.relu(conv(x, mask))
|
115 |
+
|
116 |
+
codes = x.permute(0, 2, 1) * mask.unsqueeze(-1)
|
117 |
+
|
118 |
+
return codes
|
119 |
+
|
120 |
+
|
121 |
+
|
122 |
+
class Decoder(nn.Module):
|
123 |
+
"""Decoder module:
|
124 |
+
"""
|
125 |
+
def __init__(self, hparams):
|
126 |
+
super(Decoder, self).__init__()
|
127 |
+
self.dim_enc = hparams.dim_enc_sea
|
128 |
+
self.dim_emb = hparams.dim_spk
|
129 |
+
self.dim_freq = hparams.dim_freq_sp
|
130 |
+
self.dim_neck = hparams.dim_neck_sea
|
131 |
+
|
132 |
+
self.lstm = nn.LSTM(self.dim_neck+self.dim_emb,
|
133 |
+
1024, 3, batch_first=True)
|
134 |
+
|
135 |
+
self.linear_projection = LinearNorm(1024, self.dim_freq)
|
136 |
+
|
137 |
+
def forward(self, x):
|
138 |
+
|
139 |
+
outputs = self.lstm(x)[0]
|
140 |
+
|
141 |
+
decoder_output = self.linear_projection(outputs)
|
142 |
+
|
143 |
+
return decoder_output
|
144 |
+
|
145 |
+
|
146 |
+
|
147 |
+
|
148 |
+
class Generator(nn.Module):
|
149 |
+
"""Generator network."""
|
150 |
+
def __init__(self, hparams):
|
151 |
+
super(Generator, self).__init__()
|
152 |
+
|
153 |
+
self.encoder = Encoder(hparams)
|
154 |
+
self.decoder = Decoder(hparams)
|
155 |
+
|
156 |
+
def forward(self, x, c_trg):
|
157 |
+
|
158 |
+
x = x.transpose(2,1)
|
159 |
+
codes = self.encoder(x)
|
160 |
+
|
161 |
+
encoder_outputs = torch.cat((codes,
|
162 |
+
c_trg.unsqueeze(1).expand(-1,x.size(-1),-1)), dim=-1)
|
163 |
+
mel_outputs = self.decoder(encoder_outputs)
|
164 |
+
|
165 |
+
return mel_outputs
|
166 |
+
|
167 |
+
def encode(self, x, mask):
|
168 |
+
x = x.transpose(2,1)
|
169 |
+
codes = self.encoder(x, mask)
|
170 |
+
return codes
|
171 |
+
|
172 |
+
def decode(self, codes, c_trg):
|
173 |
+
encoder_outputs = torch.cat((codes,
|
174 |
+
c_trg.unsqueeze(1).expand(-1,codes.size(1),-1)), dim=-1)
|
175 |
+
mel_outputs = self.decoder(encoder_outputs)
|
176 |
+
|
177 |
+
return mel_outputs
|
178 |
+
|
179 |
+
|
180 |
+
|
181 |
+
class Encoder_2(nn.Module):
|
182 |
+
"""Encoder module:
|
183 |
+
"""
|
184 |
+
def __init__(self, hparams):
|
185 |
+
super().__init__()
|
186 |
+
|
187 |
+
self.dim_freq = hparams.dim_freq_sea
|
188 |
+
self.dim_enc = hparams.dim_enc_sea
|
189 |
+
self.chs_grp = hparams.chs_grp
|
190 |
+
self.dim_neck = hparams.dim_neck_sea
|
191 |
+
|
192 |
+
convolutions = []
|
193 |
+
for i in range(5):
|
194 |
+
conv_layer = M43_Sequential(
|
195 |
+
ConvNorm(self.dim_freq if i==0 else self.dim_enc,
|
196 |
+
self.dim_enc,
|
197 |
+
kernel_size=5, stride=1,
|
198 |
+
padding=2,
|
199 |
+
dilation=1, w_init_gain='relu'),
|
200 |
+
GroupNorm_Mask(self.dim_enc//self.chs_grp, self.dim_enc))
|
201 |
+
convolutions.append(conv_layer)
|
202 |
+
|
203 |
+
conv_layer = M43_Sequential(
|
204 |
+
ConvNorm(self.dim_enc,
|
205 |
+
128,
|
206 |
+
kernel_size=5, stride=1,
|
207 |
+
padding=2,
|
208 |
+
dilation=1, w_init_gain='relu'),
|
209 |
+
GroupNorm_Mask(128//self.chs_grp, 128))
|
210 |
+
convolutions.append(conv_layer)
|
211 |
+
|
212 |
+
conv_layer = M43_Sequential(
|
213 |
+
ConvNorm(128,
|
214 |
+
32,
|
215 |
+
kernel_size=5, stride=1,
|
216 |
+
padding=2,
|
217 |
+
dilation=1, w_init_gain='relu'),
|
218 |
+
GroupNorm_Mask(32//self.chs_grp, 32))
|
219 |
+
convolutions.append(conv_layer)
|
220 |
+
|
221 |
+
conv_layer = M43_Sequential(
|
222 |
+
ConvNorm(32,
|
223 |
+
self.dim_neck,
|
224 |
+
kernel_size=5, stride=1,
|
225 |
+
padding=2,
|
226 |
+
dilation=1, w_init_gain='linear'),
|
227 |
+
GroupNorm_Mask(1, self.dim_neck))
|
228 |
+
convolutions.append(conv_layer)
|
229 |
+
|
230 |
+
self.convolutions = nn.ModuleList(convolutions)
|
231 |
+
|
232 |
+
|
233 |
+
def forward(self, x, mask):
|
234 |
+
|
235 |
+
for i in range(len(self.convolutions)-1):
|
236 |
+
x = F.relu(self.convolutions[i](x, mask))
|
237 |
+
|
238 |
+
x = self.convolutions[-1](x, mask)
|
239 |
+
|
240 |
+
codes = x.permute(0, 2, 1) * mask.unsqueeze(-1)
|
241 |
+
|
242 |
+
return codes
|
onmt_modules/__init__.py
ADDED
File without changes
|
onmt_modules/average_attn.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""Average Attention module."""
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from .position_ffn import PositionwiseFeedForward
|
8 |
+
|
9 |
+
|
10 |
+
class AverageAttention(nn.Module):
|
11 |
+
"""
|
12 |
+
Average Attention module from
|
13 |
+
"Accelerating Neural Transformer via an Average Attention Network"
|
14 |
+
:cite:`DBLP:journals/corr/abs-1805-00631`.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
model_dim (int): the dimension of keys/values/queries,
|
18 |
+
must be divisible by head_count
|
19 |
+
dropout (float): dropout parameter
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, model_dim, dropout=0.1, aan_useffn=False):
|
23 |
+
self.model_dim = model_dim
|
24 |
+
self.aan_useffn = aan_useffn
|
25 |
+
super(AverageAttention, self).__init__()
|
26 |
+
if aan_useffn:
|
27 |
+
self.average_layer = PositionwiseFeedForward(model_dim, model_dim,
|
28 |
+
dropout)
|
29 |
+
self.gating_layer = nn.Linear(model_dim * 2, model_dim * 2)
|
30 |
+
|
31 |
+
def cumulative_average_mask(self, batch_size, inputs_len, device):
|
32 |
+
"""
|
33 |
+
Builds the mask to compute the cumulative average as described in
|
34 |
+
:cite:`DBLP:journals/corr/abs-1805-00631` -- Figure 3
|
35 |
+
|
36 |
+
Args:
|
37 |
+
batch_size (int): batch size
|
38 |
+
inputs_len (int): length of the inputs
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
(FloatTensor):
|
42 |
+
|
43 |
+
* A Tensor of shape ``(batch_size, input_len, input_len)``
|
44 |
+
"""
|
45 |
+
|
46 |
+
triangle = torch.tril(torch.ones(inputs_len, inputs_len,
|
47 |
+
dtype=torch.float, device=device))
|
48 |
+
weights = torch.ones(1, inputs_len, dtype=torch.float, device=device) \
|
49 |
+
/ torch.arange(1, inputs_len + 1, dtype=torch.float, device=device)
|
50 |
+
mask = triangle * weights.transpose(0, 1)
|
51 |
+
|
52 |
+
return mask.unsqueeze(0).expand(batch_size, inputs_len, inputs_len)
|
53 |
+
|
54 |
+
def cumulative_average(self, inputs, mask_or_step,
|
55 |
+
layer_cache=None, step=None):
|
56 |
+
"""
|
57 |
+
Computes the cumulative average as described in
|
58 |
+
:cite:`DBLP:journals/corr/abs-1805-00631` -- Equations (1) (5) (6)
|
59 |
+
|
60 |
+
Args:
|
61 |
+
inputs (FloatTensor): sequence to average
|
62 |
+
``(batch_size, input_len, dimension)``
|
63 |
+
mask_or_step: if cache is set, this is assumed
|
64 |
+
to be the current step of the
|
65 |
+
dynamic decoding. Otherwise, it is the mask matrix
|
66 |
+
used to compute the cumulative average.
|
67 |
+
layer_cache: a dictionary containing the cumulative average
|
68 |
+
of the previous step.
|
69 |
+
|
70 |
+
Returns:
|
71 |
+
a tensor of the same shape and type as ``inputs``.
|
72 |
+
"""
|
73 |
+
|
74 |
+
if layer_cache is not None:
|
75 |
+
step = mask_or_step
|
76 |
+
average_attention = (inputs + step *
|
77 |
+
layer_cache["prev_g"]) / (step + 1)
|
78 |
+
layer_cache["prev_g"] = average_attention
|
79 |
+
return average_attention
|
80 |
+
else:
|
81 |
+
mask = mask_or_step
|
82 |
+
return torch.matmul(mask.to(inputs.dtype), inputs)
|
83 |
+
|
84 |
+
def forward(self, inputs, mask=None, layer_cache=None, step=None):
|
85 |
+
"""
|
86 |
+
Args:
|
87 |
+
inputs (FloatTensor): ``(batch_size, input_len, model_dim)``
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
(FloatTensor, FloatTensor):
|
91 |
+
|
92 |
+
* gating_outputs ``(batch_size, input_len, model_dim)``
|
93 |
+
* average_outputs average attention
|
94 |
+
``(batch_size, input_len, model_dim)``
|
95 |
+
"""
|
96 |
+
|
97 |
+
batch_size = inputs.size(0)
|
98 |
+
inputs_len = inputs.size(1)
|
99 |
+
average_outputs = self.cumulative_average(
|
100 |
+
inputs, self.cumulative_average_mask(batch_size,
|
101 |
+
inputs_len, inputs.device)
|
102 |
+
if layer_cache is None else step, layer_cache=layer_cache)
|
103 |
+
if self.aan_useffn:
|
104 |
+
average_outputs = self.average_layer(average_outputs)
|
105 |
+
gating_outputs = self.gating_layer(torch.cat((inputs,
|
106 |
+
average_outputs), -1))
|
107 |
+
input_gate, forget_gate = torch.chunk(gating_outputs, 2, dim=2)
|
108 |
+
gating_outputs = torch.sigmoid(input_gate) * inputs + \
|
109 |
+
torch.sigmoid(forget_gate) * average_outputs
|
110 |
+
|
111 |
+
return gating_outputs, average_outputs
|
onmt_modules/decoder.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .misc import aeq
|
5 |
+
|
6 |
+
|
7 |
+
class DecoderBase(nn.Module):
|
8 |
+
"""Abstract class for decoders.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
attentional (bool): The decoder returns non-empty attention.
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self, attentional=True):
|
15 |
+
super(DecoderBase, self).__init__()
|
16 |
+
self.attentional = attentional
|
17 |
+
|
18 |
+
@classmethod
|
19 |
+
def from_opt(cls, opt, embeddings):
|
20 |
+
"""Alternate constructor.
|
21 |
+
|
22 |
+
Subclasses should override this method.
|
23 |
+
"""
|
24 |
+
|
25 |
+
raise NotImplementedError
|
onmt_modules/decoder_transformer.py
ADDED
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Implementation of "Attention is All You Need"
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from .decoder import DecoderBase
|
9 |
+
from .multi_headed_attn import MultiHeadedAttention
|
10 |
+
from .average_attn import AverageAttention
|
11 |
+
from .position_ffn import PositionwiseFeedForward
|
12 |
+
from .misc import sequence_mask
|
13 |
+
|
14 |
+
|
15 |
+
class TransformerDecoderLayer(nn.Module):
|
16 |
+
"""Transformer Decoder layer block in Pre-Norm style.
|
17 |
+
Pre-Norm style is an improvement w.r.t. Original paper's Post-Norm style,
|
18 |
+
providing better converge speed and performance. This is also the actual
|
19 |
+
implementation in tensor2tensor and also avalable in fairseq.
|
20 |
+
See https://tunz.kr/post/4 and :cite:`DeeperTransformer`.
|
21 |
+
|
22 |
+
.. mermaid::
|
23 |
+
|
24 |
+
graph LR
|
25 |
+
%% "*SubLayer" can be self-attn, src-attn or feed forward block
|
26 |
+
A(input) --> B[Norm]
|
27 |
+
B --> C["*SubLayer"]
|
28 |
+
C --> D[Drop]
|
29 |
+
D --> E((+))
|
30 |
+
A --> E
|
31 |
+
E --> F(out)
|
32 |
+
|
33 |
+
|
34 |
+
Args:
|
35 |
+
d_model (int): the dimension of keys/values/queries in
|
36 |
+
:class:`MultiHeadedAttention`, also the input size of
|
37 |
+
the first-layer of the :class:`PositionwiseFeedForward`.
|
38 |
+
heads (int): the number of heads for MultiHeadedAttention.
|
39 |
+
d_ff (int): the second-layer of the :class:`PositionwiseFeedForward`.
|
40 |
+
dropout (float): dropout in residual, self-attn(dot) and feed-forward
|
41 |
+
attention_dropout (float): dropout in context_attn (and self-attn(avg))
|
42 |
+
self_attn_type (string): type of self-attention scaled-dot, average
|
43 |
+
max_relative_positions (int):
|
44 |
+
Max distance between inputs in relative positions representations
|
45 |
+
aan_useffn (bool): Turn on the FFN layer in the AAN decoder
|
46 |
+
full_context_alignment (bool):
|
47 |
+
whether enable an extra full context decoder forward for alignment
|
48 |
+
alignment_heads (int):
|
49 |
+
N. of cross attention heads to use for alignment guiding
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(self, d_model, heads, d_ff, dropout, attention_dropout,
|
53 |
+
self_attn_type="scaled-dot", max_relative_positions=0,
|
54 |
+
aan_useffn=False, full_context_alignment=False,
|
55 |
+
alignment_heads=0):
|
56 |
+
super(TransformerDecoderLayer, self).__init__()
|
57 |
+
|
58 |
+
if self_attn_type == "scaled-dot":
|
59 |
+
self.self_attn = MultiHeadedAttention(
|
60 |
+
heads, d_model, dropout=attention_dropout,
|
61 |
+
max_relative_positions=max_relative_positions)
|
62 |
+
elif self_attn_type == "average":
|
63 |
+
self.self_attn = AverageAttention(d_model,
|
64 |
+
dropout=attention_dropout,
|
65 |
+
aan_useffn=aan_useffn)
|
66 |
+
|
67 |
+
self.context_attn = MultiHeadedAttention(
|
68 |
+
heads, d_model, dropout=attention_dropout)
|
69 |
+
self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
|
70 |
+
self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
|
71 |
+
self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
|
72 |
+
self.drop = nn.Dropout(dropout)
|
73 |
+
self.full_context_alignment = full_context_alignment
|
74 |
+
self.alignment_heads = alignment_heads
|
75 |
+
|
76 |
+
def forward(self, *args, **kwargs):
|
77 |
+
""" Extend `_forward` for (possibly) multiple decoder pass:
|
78 |
+
Always a default (future masked) decoder forward pass,
|
79 |
+
Possibly a second future aware decoder pass for joint learn
|
80 |
+
full context alignement, :cite:`garg2019jointly`.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
* All arguments of _forward.
|
84 |
+
with_align (bool): whether return alignment attention.
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
(FloatTensor, FloatTensor, FloatTensor or None):
|
88 |
+
|
89 |
+
* output ``(batch_size, T, model_dim)``
|
90 |
+
* top_attn ``(batch_size, T, src_len)``
|
91 |
+
* attn_align ``(batch_size, T, src_len)`` or None
|
92 |
+
"""
|
93 |
+
with_align = kwargs.pop('with_align', False)
|
94 |
+
output, attns = self._forward(*args, **kwargs)
|
95 |
+
top_attn = attns[:, 0, :, :].contiguous()
|
96 |
+
attn_align = None
|
97 |
+
if with_align:
|
98 |
+
if self.full_context_alignment:
|
99 |
+
# return _, (B, Q_len, K_len)
|
100 |
+
_, attns = self._forward(*args, **kwargs, future=True)
|
101 |
+
|
102 |
+
if self.alignment_heads > 0:
|
103 |
+
attns = attns[:, :self.alignment_heads, :, :].contiguous()
|
104 |
+
# layer average attention across heads, get ``(B, Q, K)``
|
105 |
+
# Case 1: no full_context, no align heads -> layer avg baseline
|
106 |
+
# Case 2: no full_context, 1 align heads -> guided align
|
107 |
+
# Case 3: full_context, 1 align heads -> full cte guided align
|
108 |
+
attn_align = attns.mean(dim=1)
|
109 |
+
return output, top_attn, attn_align
|
110 |
+
|
111 |
+
def _forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask,
|
112 |
+
layer_cache=None, step=None, future=False):
|
113 |
+
""" A naive forward pass for transformer decoder.
|
114 |
+
|
115 |
+
# T: could be 1 in the case of stepwise decoding or tgt_len
|
116 |
+
|
117 |
+
Args:
|
118 |
+
inputs (FloatTensor): ``(batch_size, T, model_dim)``
|
119 |
+
memory_bank (FloatTensor): ``(batch_size, src_len, model_dim)``
|
120 |
+
src_pad_mask (LongTensor): ``(batch_size, 1, src_len)``
|
121 |
+
tgt_pad_mask (LongTensor): ``(batch_size, 1, T)``
|
122 |
+
layer_cache (dict or None): cached layer info when stepwise decode
|
123 |
+
step (int or None): stepwise decoding counter
|
124 |
+
future (bool): If set True, do not apply future_mask.
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
(FloatTensor, FloatTensor):
|
128 |
+
|
129 |
+
* output ``(batch_size, T, model_dim)``
|
130 |
+
* attns ``(batch_size, head, T, src_len)``
|
131 |
+
|
132 |
+
"""
|
133 |
+
dec_mask = None
|
134 |
+
|
135 |
+
if step is None:
|
136 |
+
tgt_len = tgt_pad_mask.size(-1)
|
137 |
+
if not future: # apply future_mask, result mask in (B, T, T)
|
138 |
+
future_mask = torch.ones(
|
139 |
+
[tgt_len, tgt_len],
|
140 |
+
device=tgt_pad_mask.device,
|
141 |
+
dtype=torch.uint8)
|
142 |
+
future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
|
143 |
+
# BoolTensor was introduced in pytorch 1.2
|
144 |
+
try:
|
145 |
+
future_mask = future_mask.bool()
|
146 |
+
except AttributeError:
|
147 |
+
pass
|
148 |
+
dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
|
149 |
+
else: # only mask padding, result mask in (B, 1, T)
|
150 |
+
dec_mask = tgt_pad_mask
|
151 |
+
|
152 |
+
input_norm = self.layer_norm_1(inputs)
|
153 |
+
|
154 |
+
if isinstance(self.self_attn, MultiHeadedAttention):
|
155 |
+
query, _ = self.self_attn(input_norm, input_norm, input_norm,
|
156 |
+
mask=dec_mask,
|
157 |
+
layer_cache=layer_cache,
|
158 |
+
attn_type="self")
|
159 |
+
elif isinstance(self.self_attn, AverageAttention):
|
160 |
+
query, _ = self.self_attn(input_norm, mask=dec_mask,
|
161 |
+
layer_cache=layer_cache, step=step)
|
162 |
+
|
163 |
+
query = self.drop(query) + inputs
|
164 |
+
|
165 |
+
query_norm = self.layer_norm_2(query)
|
166 |
+
mid, attns = self.context_attn(memory_bank, memory_bank, query_norm,
|
167 |
+
mask=src_pad_mask,
|
168 |
+
layer_cache=layer_cache,
|
169 |
+
attn_type="context")
|
170 |
+
output = self.feed_forward(self.drop(mid) + query)
|
171 |
+
|
172 |
+
return output, attns
|
173 |
+
|
174 |
+
def update_dropout(self, dropout, attention_dropout):
|
175 |
+
self.self_attn.update_dropout(attention_dropout)
|
176 |
+
self.context_attn.update_dropout(attention_dropout)
|
177 |
+
self.feed_forward.update_dropout(dropout)
|
178 |
+
self.drop.p = dropout
|
179 |
+
|
180 |
+
|
181 |
+
class TransformerDecoder(DecoderBase):
|
182 |
+
"""The Transformer decoder from "Attention is All You Need".
|
183 |
+
:cite:`DBLP:journals/corr/VaswaniSPUJGKP17`
|
184 |
+
|
185 |
+
.. mermaid::
|
186 |
+
|
187 |
+
graph BT
|
188 |
+
A[input]
|
189 |
+
B[multi-head self-attn]
|
190 |
+
BB[multi-head src-attn]
|
191 |
+
C[feed forward]
|
192 |
+
O[output]
|
193 |
+
A --> B
|
194 |
+
B --> BB
|
195 |
+
BB --> C
|
196 |
+
C --> O
|
197 |
+
|
198 |
+
|
199 |
+
Args:
|
200 |
+
num_layers (int): number of encoder layers.
|
201 |
+
d_model (int): size of the model
|
202 |
+
heads (int): number of heads
|
203 |
+
d_ff (int): size of the inner FF layer
|
204 |
+
copy_attn (bool): if using a separate copy attention
|
205 |
+
self_attn_type (str): type of self-attention scaled-dot, average
|
206 |
+
dropout (float): dropout in residual, self-attn(dot) and feed-forward
|
207 |
+
attention_dropout (float): dropout in context_attn (and self-attn(avg))
|
208 |
+
embeddings (onmt.modules.Embeddings):
|
209 |
+
embeddings to use, should have positional encodings
|
210 |
+
max_relative_positions (int):
|
211 |
+
Max distance between inputs in relative positions representations
|
212 |
+
aan_useffn (bool): Turn on the FFN layer in the AAN decoder
|
213 |
+
full_context_alignment (bool):
|
214 |
+
whether enable an extra full context decoder forward for alignment
|
215 |
+
alignment_layer (int): N° Layer to supervise with for alignment guiding
|
216 |
+
alignment_heads (int):
|
217 |
+
N. of cross attention heads to use for alignment guiding
|
218 |
+
"""
|
219 |
+
|
220 |
+
def __init__(self, num_layers, d_model, heads, d_ff,
|
221 |
+
copy_attn, self_attn_type, dropout, attention_dropout,
|
222 |
+
embeddings, max_relative_positions, aan_useffn,
|
223 |
+
full_context_alignment, alignment_layer,
|
224 |
+
alignment_heads):
|
225 |
+
super(TransformerDecoder, self).__init__()
|
226 |
+
|
227 |
+
self.embeddings = embeddings
|
228 |
+
|
229 |
+
# Decoder State
|
230 |
+
self.state = {}
|
231 |
+
|
232 |
+
self.transformer_layers = nn.ModuleList(
|
233 |
+
[TransformerDecoderLayer(d_model, heads, d_ff, dropout,
|
234 |
+
attention_dropout, self_attn_type=self_attn_type,
|
235 |
+
max_relative_positions=max_relative_positions,
|
236 |
+
aan_useffn=aan_useffn,
|
237 |
+
full_context_alignment=full_context_alignment,
|
238 |
+
alignment_heads=alignment_heads)
|
239 |
+
for i in range(num_layers)])
|
240 |
+
|
241 |
+
# previously, there was a GlobalAttention module here for copy
|
242 |
+
# attention. But it was never actually used -- the "copy" attention
|
243 |
+
# just reuses the context attention.
|
244 |
+
self._copy = copy_attn
|
245 |
+
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
|
246 |
+
|
247 |
+
self.alignment_layer = alignment_layer
|
248 |
+
|
249 |
+
@classmethod
|
250 |
+
def from_opt(cls, opt, embeddings):
|
251 |
+
"""Alternate constructor."""
|
252 |
+
return cls(
|
253 |
+
opt.dec_layers,
|
254 |
+
opt.dec_rnn_size,
|
255 |
+
opt.heads,
|
256 |
+
opt.transformer_ff,
|
257 |
+
opt.copy_attn,
|
258 |
+
opt.self_attn_type,
|
259 |
+
opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
|
260 |
+
opt.attention_dropout[0] if type(opt.attention_dropout)
|
261 |
+
is list else opt.dropout,
|
262 |
+
embeddings,
|
263 |
+
opt.max_relative_positions,
|
264 |
+
opt.aan_useffn,
|
265 |
+
opt.full_context_alignment,
|
266 |
+
opt.alignment_layer,
|
267 |
+
alignment_heads=opt.alignment_heads)
|
268 |
+
|
269 |
+
def init_state(self, src, memory_bank, enc_hidden):
|
270 |
+
"""Initialize decoder state."""
|
271 |
+
self.state["src"] = src
|
272 |
+
self.state["cache"] = None
|
273 |
+
|
274 |
+
def map_state(self, fn):
|
275 |
+
def _recursive_map(struct, batch_dim=0):
|
276 |
+
for k, v in struct.items():
|
277 |
+
if v is not None:
|
278 |
+
if isinstance(v, dict):
|
279 |
+
_recursive_map(v)
|
280 |
+
else:
|
281 |
+
struct[k] = fn(v, batch_dim)
|
282 |
+
|
283 |
+
self.state["src"] = fn(self.state["src"], 1)
|
284 |
+
if self.state["cache"] is not None:
|
285 |
+
_recursive_map(self.state["cache"])
|
286 |
+
|
287 |
+
def detach_state(self):
|
288 |
+
self.state["src"] = self.state["src"].detach()
|
289 |
+
|
290 |
+
def forward(self, tgt, memory_bank, step=None, **kwargs):
|
291 |
+
"""Decode, possibly stepwise."""
|
292 |
+
if step == 0:
|
293 |
+
self._init_cache(memory_bank)
|
294 |
+
|
295 |
+
tgt_words = tgt[:, :, 0].transpose(0, 1)
|
296 |
+
|
297 |
+
emb = self.embeddings(tgt, step=step)
|
298 |
+
assert emb.dim() == 3 # len x batch x embedding_dim
|
299 |
+
|
300 |
+
output = emb.transpose(0, 1).contiguous()
|
301 |
+
src_memory_bank = memory_bank.transpose(0, 1).contiguous()
|
302 |
+
|
303 |
+
pad_idx = self.embeddings.word_padding_idx
|
304 |
+
src_lens = kwargs["memory_lengths"]
|
305 |
+
src_max_len = self.state["src"].shape[0]
|
306 |
+
src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1)
|
307 |
+
tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1) # [B, 1, T_tgt]
|
308 |
+
|
309 |
+
with_align = kwargs.pop('with_align', False)
|
310 |
+
attn_aligns = []
|
311 |
+
|
312 |
+
for i, layer in enumerate(self.transformer_layers):
|
313 |
+
layer_cache = self.state["cache"]["layer_{}".format(i)] \
|
314 |
+
if step is not None else None
|
315 |
+
output, attn, attn_align = layer(
|
316 |
+
output,
|
317 |
+
src_memory_bank,
|
318 |
+
src_pad_mask,
|
319 |
+
tgt_pad_mask,
|
320 |
+
layer_cache=layer_cache,
|
321 |
+
step=step,
|
322 |
+
with_align=with_align)
|
323 |
+
if attn_align is not None:
|
324 |
+
attn_aligns.append(attn_align)
|
325 |
+
|
326 |
+
output = self.layer_norm(output)
|
327 |
+
dec_outs = output.transpose(0, 1).contiguous()
|
328 |
+
attn = attn.transpose(0, 1).contiguous()
|
329 |
+
|
330 |
+
attns = {"std": attn}
|
331 |
+
if self._copy:
|
332 |
+
attns["copy"] = attn
|
333 |
+
if with_align:
|
334 |
+
attns["align"] = attn_aligns[self.alignment_layer] # `(B, Q, K)`
|
335 |
+
# attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg
|
336 |
+
|
337 |
+
# TODO change the way attns is returned dict => list or tuple (onnx)
|
338 |
+
return dec_outs, attns
|
339 |
+
|
340 |
+
def _init_cache(self, memory_bank):
|
341 |
+
self.state["cache"] = {}
|
342 |
+
batch_size = memory_bank.size(1)
|
343 |
+
depth = memory_bank.size(-1)
|
344 |
+
|
345 |
+
for i, layer in enumerate(self.transformer_layers):
|
346 |
+
layer_cache = {"memory_keys": None, "memory_values": None}
|
347 |
+
if isinstance(layer.self_attn, AverageAttention):
|
348 |
+
layer_cache["prev_g"] = torch.zeros((batch_size, 1, depth),
|
349 |
+
device=memory_bank.device)
|
350 |
+
else:
|
351 |
+
layer_cache["self_keys"] = None
|
352 |
+
layer_cache["self_values"] = None
|
353 |
+
self.state["cache"]["layer_{}".format(i)] = layer_cache
|
354 |
+
|
355 |
+
def update_dropout(self, dropout, attention_dropout):
|
356 |
+
self.embeddings.update_dropout(dropout)
|
357 |
+
for layer in self.transformer_layers:
|
358 |
+
layer.update_dropout(dropout, attention_dropout)
|
onmt_modules/embeddings.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Embeddings module """
|
2 |
+
import math
|
3 |
+
import warnings
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
|
9 |
+
class PositionalEncoding(nn.Module):
|
10 |
+
"""Sinusoidal positional encoding for non-recurrent neural networks.
|
11 |
+
|
12 |
+
Implementation based on "Attention Is All You Need"
|
13 |
+
:cite:`DBLP:journals/corr/VaswaniSPUJGKP17`
|
14 |
+
|
15 |
+
Args:
|
16 |
+
dropout (float): dropout parameter
|
17 |
+
dim (int): embedding size
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, dropout, dim, max_len=5000):
|
21 |
+
if dim % 2 != 0:
|
22 |
+
raise ValueError("Cannot use sin/cos positional encoding with "
|
23 |
+
"odd dim (got dim={:d})".format(dim))
|
24 |
+
pe = torch.zeros(max_len, dim)
|
25 |
+
position = torch.arange(0, max_len).unsqueeze(1)
|
26 |
+
div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
|
27 |
+
-(math.log(10000.0) / dim)))
|
28 |
+
pe[:, 0::2] = torch.sin(position.float() * div_term)
|
29 |
+
pe[:, 1::2] = torch.cos(position.float() * div_term)
|
30 |
+
pe = pe.unsqueeze(1)
|
31 |
+
super(PositionalEncoding, self).__init__()
|
32 |
+
self.register_buffer('pe', pe)
|
33 |
+
self.dropout = nn.Dropout(p=dropout)
|
34 |
+
self.dim = dim
|
35 |
+
|
36 |
+
def forward(self, emb, step=None):
|
37 |
+
"""Embed inputs.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
emb (FloatTensor): Sequence of word vectors
|
41 |
+
``(seq_len, batch_size, self.dim)``
|
42 |
+
step (int or NoneType): If stepwise (``seq_len = 1``), use
|
43 |
+
the encoding for this position.
|
44 |
+
"""
|
45 |
+
|
46 |
+
emb = emb * math.sqrt(self.dim)
|
47 |
+
if step is None:
|
48 |
+
emb = emb + self.pe[:emb.size(0)]
|
49 |
+
else:
|
50 |
+
emb = emb + self.pe[step]
|
51 |
+
emb = self.dropout(emb)
|
52 |
+
return emb
|
onmt_modules/encoder.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Base class for encoders and generic multi encoders."""
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from .misc import aeq
|
6 |
+
|
7 |
+
|
8 |
+
class EncoderBase(nn.Module):
|
9 |
+
"""
|
10 |
+
Base encoder class. Specifies the interface used by different encoder types
|
11 |
+
and required by :class:`onmt.Models.NMTModel`.
|
12 |
+
|
13 |
+
.. mermaid::
|
14 |
+
|
15 |
+
graph BT
|
16 |
+
A[Input]
|
17 |
+
subgraph RNN
|
18 |
+
C[Pos 1]
|
19 |
+
D[Pos 2]
|
20 |
+
E[Pos N]
|
21 |
+
end
|
22 |
+
F[Memory_Bank]
|
23 |
+
G[Final]
|
24 |
+
A-->C
|
25 |
+
A-->D
|
26 |
+
A-->E
|
27 |
+
C-->F
|
28 |
+
D-->F
|
29 |
+
E-->F
|
30 |
+
E-->G
|
31 |
+
"""
|
32 |
+
|
33 |
+
@classmethod
|
34 |
+
def from_opt(cls, opt, embeddings=None):
|
35 |
+
raise NotImplementedError
|
36 |
+
|
37 |
+
def _check_args(self, src, lengths=None, hidden=None):
|
38 |
+
n_batch = src.size(1)
|
39 |
+
if lengths is not None:
|
40 |
+
n_batch_, = lengths.size()
|
41 |
+
aeq(n_batch, n_batch_)
|
42 |
+
|
43 |
+
def forward(self, src, lengths=None):
|
44 |
+
"""
|
45 |
+
Args:
|
46 |
+
src (LongTensor):
|
47 |
+
padded sequences of sparse indices ``(src_len, batch, nfeat)``
|
48 |
+
lengths (LongTensor): length of each sequence ``(batch,)``
|
49 |
+
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
(FloatTensor, FloatTensor):
|
53 |
+
|
54 |
+
* final encoder state, used to initialize decoder
|
55 |
+
* memory bank for attention, ``(src_len, batch, hidden)``
|
56 |
+
"""
|
57 |
+
|
58 |
+
raise NotImplementedError
|
onmt_modules/encoder_transformer.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Implementation of "Attention is All You Need"
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from .encoder import EncoderBase
|
8 |
+
from .multi_headed_attn import MultiHeadedAttention
|
9 |
+
from .position_ffn import PositionwiseFeedForward
|
10 |
+
from .misc import sequence_mask
|
11 |
+
|
12 |
+
|
13 |
+
class TransformerEncoderLayer(nn.Module):
|
14 |
+
"""
|
15 |
+
A single layer of the transformer encoder.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
d_model (int): the dimension of keys/values/queries in
|
19 |
+
MultiHeadedAttention, also the input size of
|
20 |
+
the first-layer of the PositionwiseFeedForward.
|
21 |
+
heads (int): the number of head for MultiHeadedAttention.
|
22 |
+
d_ff (int): the second-layer of the PositionwiseFeedForward.
|
23 |
+
dropout (float): dropout probability(0-1.0).
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self, d_model, heads, d_ff, dropout, attention_dropout,
|
27 |
+
max_relative_positions=0):
|
28 |
+
super(TransformerEncoderLayer, self).__init__()
|
29 |
+
|
30 |
+
self.self_attn = MultiHeadedAttention(
|
31 |
+
heads, d_model, dropout=attention_dropout,
|
32 |
+
max_relative_positions=max_relative_positions)
|
33 |
+
self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
|
34 |
+
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
|
35 |
+
self.dropout = nn.Dropout(dropout)
|
36 |
+
|
37 |
+
def forward(self, inputs, mask):
|
38 |
+
"""
|
39 |
+
Args:
|
40 |
+
inputs (FloatTensor): ``(batch_size, src_len, model_dim)``
|
41 |
+
mask (LongTensor): ``(batch_size, 1, src_len)``
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
(FloatTensor):
|
45 |
+
|
46 |
+
* outputs ``(batch_size, src_len, model_dim)``
|
47 |
+
"""
|
48 |
+
input_norm = self.layer_norm(inputs)
|
49 |
+
context, _ = self.self_attn(input_norm, input_norm, input_norm,
|
50 |
+
mask=mask, attn_type="self")
|
51 |
+
out = self.dropout(context) + inputs
|
52 |
+
return self.feed_forward(out)
|
53 |
+
|
54 |
+
def update_dropout(self, dropout, attention_dropout):
|
55 |
+
self.self_attn.update_dropout(attention_dropout)
|
56 |
+
self.feed_forward.update_dropout(dropout)
|
57 |
+
self.dropout.p = dropout
|
58 |
+
|
59 |
+
|
60 |
+
class TransformerEncoder(EncoderBase):
|
61 |
+
"""The Transformer encoder from "Attention is All You Need"
|
62 |
+
:cite:`DBLP:journals/corr/VaswaniSPUJGKP17`
|
63 |
+
|
64 |
+
.. mermaid::
|
65 |
+
|
66 |
+
graph BT
|
67 |
+
A[input]
|
68 |
+
B[multi-head self-attn]
|
69 |
+
C[feed forward]
|
70 |
+
O[output]
|
71 |
+
A --> B
|
72 |
+
B --> C
|
73 |
+
C --> O
|
74 |
+
|
75 |
+
Args:
|
76 |
+
num_layers (int): number of encoder layers
|
77 |
+
d_model (int): size of the model
|
78 |
+
heads (int): number of heads
|
79 |
+
d_ff (int): size of the inner FF layer
|
80 |
+
dropout (float): dropout parameters
|
81 |
+
embeddings (onmt.modules.Embeddings):
|
82 |
+
embeddings to use, should have positional encodings
|
83 |
+
|
84 |
+
Returns:
|
85 |
+
(torch.FloatTensor, torch.FloatTensor):
|
86 |
+
|
87 |
+
* embeddings ``(src_len, batch_size, model_dim)``
|
88 |
+
* memory_bank ``(src_len, batch_size, model_dim)``
|
89 |
+
"""
|
90 |
+
|
91 |
+
def __init__(self, num_layers, d_model, heads, d_ff, dropout,
|
92 |
+
attention_dropout, embeddings, max_relative_positions):
|
93 |
+
super(TransformerEncoder, self).__init__()
|
94 |
+
|
95 |
+
self.embeddings = embeddings
|
96 |
+
self.transformer = nn.ModuleList(
|
97 |
+
[TransformerEncoderLayer(
|
98 |
+
d_model, heads, d_ff, dropout, attention_dropout,
|
99 |
+
max_relative_positions=max_relative_positions)
|
100 |
+
for i in range(num_layers)])
|
101 |
+
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
|
102 |
+
|
103 |
+
@classmethod
|
104 |
+
def from_opt(cls, opt, embeddings):
|
105 |
+
"""Alternate constructor."""
|
106 |
+
return cls(
|
107 |
+
opt.enc_layers,
|
108 |
+
opt.enc_rnn_size,
|
109 |
+
opt.heads,
|
110 |
+
opt.transformer_ff,
|
111 |
+
opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
|
112 |
+
opt.attention_dropout[0] if type(opt.attention_dropout)
|
113 |
+
is list else opt.attention_dropout,
|
114 |
+
embeddings,
|
115 |
+
opt.max_relative_positions)
|
116 |
+
|
117 |
+
def forward(self, src, lengths=None):
|
118 |
+
"""See :func:`EncoderBase.forward()`"""
|
119 |
+
self._check_args(src, lengths)
|
120 |
+
|
121 |
+
emb = self.embeddings(src)
|
122 |
+
|
123 |
+
out = emb.transpose(0, 1).contiguous()
|
124 |
+
mask = ~sequence_mask(lengths).unsqueeze(1)
|
125 |
+
# Run the forward pass of every layer of the tranformer.
|
126 |
+
for layer in self.transformer:
|
127 |
+
out = layer(out, mask)
|
128 |
+
out = self.layer_norm(out)
|
129 |
+
|
130 |
+
return emb, out.transpose(0, 1).contiguous(), lengths
|
131 |
+
|
132 |
+
def update_dropout(self, dropout, attention_dropout):
|
133 |
+
self.embeddings.update_dropout(dropout)
|
134 |
+
for layer in self.transformer:
|
135 |
+
layer.update_dropout(dropout, attention_dropout)
|
onmt_modules/misc.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import random
|
5 |
+
import inspect
|
6 |
+
from itertools import islice, repeat
|
7 |
+
import os
|
8 |
+
|
9 |
+
|
10 |
+
def split_corpus(path, shard_size, default=None):
|
11 |
+
"""yield a `list` containing `shard_size` line of `path`,
|
12 |
+
or repeatly generate `default` if `path` is None.
|
13 |
+
"""
|
14 |
+
if path is not None:
|
15 |
+
return _split_corpus(path, shard_size)
|
16 |
+
else:
|
17 |
+
return repeat(default)
|
18 |
+
|
19 |
+
|
20 |
+
def _split_corpus(path, shard_size):
|
21 |
+
"""Yield a `list` containing `shard_size` line of `path`.
|
22 |
+
"""
|
23 |
+
with open(path, "rb") as f:
|
24 |
+
if shard_size <= 0:
|
25 |
+
yield f.readlines()
|
26 |
+
else:
|
27 |
+
while True:
|
28 |
+
shard = list(islice(f, shard_size))
|
29 |
+
if not shard:
|
30 |
+
break
|
31 |
+
yield shard
|
32 |
+
|
33 |
+
|
34 |
+
def aeq(*args):
|
35 |
+
"""
|
36 |
+
Assert all arguments have the same value
|
37 |
+
"""
|
38 |
+
arguments = (arg for arg in args)
|
39 |
+
first = next(arguments)
|
40 |
+
assert all(arg == first for arg in arguments), \
|
41 |
+
"Not all arguments have the same value: " + str(args)
|
42 |
+
|
43 |
+
|
44 |
+
def sequence_mask(lengths, max_len=None):
|
45 |
+
"""
|
46 |
+
Creates a boolean mask from sequence lengths.
|
47 |
+
"""
|
48 |
+
batch_size = lengths.numel()
|
49 |
+
max_len = max_len or lengths.max()
|
50 |
+
return (torch.arange(0, max_len, device=lengths.device)
|
51 |
+
.type_as(lengths)
|
52 |
+
.repeat(batch_size, 1)
|
53 |
+
.lt(lengths.unsqueeze(1)))
|
54 |
+
|
55 |
+
|
56 |
+
def tile(x, count, dim=0):
|
57 |
+
"""
|
58 |
+
Tiles x on dimension dim count times.
|
59 |
+
"""
|
60 |
+
perm = list(range(len(x.size())))
|
61 |
+
if dim != 0:
|
62 |
+
perm[0], perm[dim] = perm[dim], perm[0]
|
63 |
+
x = x.permute(perm).contiguous()
|
64 |
+
out_size = list(x.size())
|
65 |
+
out_size[0] *= count
|
66 |
+
batch = x.size(0)
|
67 |
+
x = x.view(batch, -1) \
|
68 |
+
.transpose(0, 1) \
|
69 |
+
.repeat(count, 1) \
|
70 |
+
.transpose(0, 1) \
|
71 |
+
.contiguous() \
|
72 |
+
.view(*out_size)
|
73 |
+
if dim != 0:
|
74 |
+
x = x.permute(perm).contiguous()
|
75 |
+
return x
|
76 |
+
|
77 |
+
|
78 |
+
def use_gpu(opt):
|
79 |
+
"""
|
80 |
+
Creates a boolean if gpu used
|
81 |
+
"""
|
82 |
+
return (hasattr(opt, 'gpu_ranks') and len(opt.gpu_ranks) > 0) or \
|
83 |
+
(hasattr(opt, 'gpu') and opt.gpu > -1)
|
84 |
+
|
85 |
+
|
86 |
+
def set_random_seed(seed, is_cuda):
|
87 |
+
"""Sets the random seed."""
|
88 |
+
if seed > 0:
|
89 |
+
torch.manual_seed(seed)
|
90 |
+
# this one is needed for torchtext random call (shuffled iterator)
|
91 |
+
# in multi gpu it ensures datasets are read in the same order
|
92 |
+
random.seed(seed)
|
93 |
+
# some cudnn methods can be random even after fixing the seed
|
94 |
+
# unless you tell it to be deterministic
|
95 |
+
torch.backends.cudnn.deterministic = True
|
96 |
+
|
97 |
+
if is_cuda and seed > 0:
|
98 |
+
# These ensure same initialization in multi gpu mode
|
99 |
+
torch.cuda.manual_seed(seed)
|
100 |
+
|
101 |
+
|
102 |
+
def generate_relative_positions_matrix(length, max_relative_positions,
|
103 |
+
cache=False):
|
104 |
+
"""Generate the clipped relative positions matrix
|
105 |
+
for a given length and maximum relative positions"""
|
106 |
+
if cache:
|
107 |
+
distance_mat = torch.arange(-length+1, 1, 1).unsqueeze(0)
|
108 |
+
else:
|
109 |
+
range_vec = torch.arange(length)
|
110 |
+
range_mat = range_vec.unsqueeze(-1).expand(-1, length).transpose(0, 1)
|
111 |
+
distance_mat = range_mat - range_mat.transpose(0, 1)
|
112 |
+
distance_mat_clipped = torch.clamp(distance_mat,
|
113 |
+
min=-max_relative_positions,
|
114 |
+
max=max_relative_positions)
|
115 |
+
# Shift values to be >= 0
|
116 |
+
final_mat = distance_mat_clipped + max_relative_positions
|
117 |
+
return final_mat
|
118 |
+
|
119 |
+
|
120 |
+
def relative_matmul(x, z, transpose):
|
121 |
+
"""Helper function for relative positions attention."""
|
122 |
+
batch_size = x.shape[0]
|
123 |
+
heads = x.shape[1]
|
124 |
+
length = x.shape[2]
|
125 |
+
x_t = x.permute(2, 0, 1, 3)
|
126 |
+
x_t_r = x_t.reshape(length, heads * batch_size, -1)
|
127 |
+
if transpose:
|
128 |
+
z_t = z.transpose(1, 2)
|
129 |
+
x_tz_matmul = torch.matmul(x_t_r, z_t)
|
130 |
+
else:
|
131 |
+
x_tz_matmul = torch.matmul(x_t_r, z)
|
132 |
+
x_tz_matmul_r = x_tz_matmul.reshape(length, batch_size, heads, -1)
|
133 |
+
x_tz_matmul_r_t = x_tz_matmul_r.permute(1, 2, 0, 3)
|
134 |
+
return x_tz_matmul_r_t
|
135 |
+
|
136 |
+
|
137 |
+
def fn_args(fun):
|
138 |
+
"""Returns the list of function arguments name."""
|
139 |
+
return inspect.getfullargspec(fun).args
|
140 |
+
|
141 |
+
|
142 |
+
def report_matrix(row_label, column_label, matrix):
|
143 |
+
header_format = "{:>10.10} " + "{:>10.7} " * len(row_label)
|
144 |
+
row_format = "{:>10.10} " + "{:>10.7f} " * len(row_label)
|
145 |
+
output = header_format.format("", *row_label) + '\n'
|
146 |
+
for word, row in zip(column_label, matrix):
|
147 |
+
max_index = row.index(max(row))
|
148 |
+
row_format = row_format.replace(
|
149 |
+
"{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
|
150 |
+
row_format = row_format.replace(
|
151 |
+
"{:*>10.7f} ", "{:>10.7f} ", max_index)
|
152 |
+
output += row_format.format(word, *row) + '\n'
|
153 |
+
row_format = "{:>10.10} " + "{:>10.7f} " * len(row_label)
|
154 |
+
return output
|
155 |
+
|
156 |
+
|
157 |
+
def check_model_config(model_config, root):
|
158 |
+
# we need to check the model path + any tokenizer path
|
159 |
+
for model in model_config["models"]:
|
160 |
+
model_path = os.path.join(root, model)
|
161 |
+
if not os.path.exists(model_path):
|
162 |
+
raise FileNotFoundError(
|
163 |
+
"{} from model {} does not exist".format(
|
164 |
+
model_path, model_config["id"]))
|
165 |
+
if "tokenizer" in model_config.keys():
|
166 |
+
if "params" in model_config["tokenizer"].keys():
|
167 |
+
for k, v in model_config["tokenizer"]["params"].items():
|
168 |
+
if k.endswith("path"):
|
169 |
+
tok_path = os.path.join(root, v)
|
170 |
+
if not os.path.exists(tok_path):
|
171 |
+
raise FileNotFoundError(
|
172 |
+
"{} from model {} does not exist".format(
|
173 |
+
tok_path, model_config["id"]))
|
onmt_modules/multi_headed_attn.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Multi-Head Attention module """
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from .misc import generate_relative_positions_matrix,\
|
7 |
+
relative_matmul
|
8 |
+
# from onmt.utils.misc import aeq
|
9 |
+
|
10 |
+
|
11 |
+
class MultiHeadedAttention(nn.Module):
|
12 |
+
"""Multi-Head Attention module from "Attention is All You Need"
|
13 |
+
:cite:`DBLP:journals/corr/VaswaniSPUJGKP17`.
|
14 |
+
|
15 |
+
Similar to standard `dot` attention but uses
|
16 |
+
multiple attention distributions simulataneously
|
17 |
+
to select relevant items.
|
18 |
+
|
19 |
+
.. mermaid::
|
20 |
+
|
21 |
+
graph BT
|
22 |
+
A[key]
|
23 |
+
B[value]
|
24 |
+
C[query]
|
25 |
+
O[output]
|
26 |
+
subgraph Attn
|
27 |
+
D[Attn 1]
|
28 |
+
E[Attn 2]
|
29 |
+
F[Attn N]
|
30 |
+
end
|
31 |
+
A --> D
|
32 |
+
C --> D
|
33 |
+
A --> E
|
34 |
+
C --> E
|
35 |
+
A --> F
|
36 |
+
C --> F
|
37 |
+
D --> O
|
38 |
+
E --> O
|
39 |
+
F --> O
|
40 |
+
B --> O
|
41 |
+
|
42 |
+
Also includes several additional tricks.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
head_count (int): number of parallel heads
|
46 |
+
model_dim (int): the dimension of keys/values/queries,
|
47 |
+
must be divisible by head_count
|
48 |
+
dropout (float): dropout parameter
|
49 |
+
"""
|
50 |
+
|
51 |
+
def __init__(self, head_count, model_dim, dropout=0.1,
|
52 |
+
max_relative_positions=0):
|
53 |
+
assert model_dim % head_count == 0
|
54 |
+
self.dim_per_head = model_dim // head_count
|
55 |
+
self.model_dim = model_dim
|
56 |
+
|
57 |
+
super(MultiHeadedAttention, self).__init__()
|
58 |
+
self.head_count = head_count
|
59 |
+
|
60 |
+
self.linear_keys = nn.Linear(model_dim,
|
61 |
+
head_count * self.dim_per_head)
|
62 |
+
self.linear_values = nn.Linear(model_dim,
|
63 |
+
head_count * self.dim_per_head)
|
64 |
+
self.linear_query = nn.Linear(model_dim,
|
65 |
+
head_count * self.dim_per_head)
|
66 |
+
self.softmax = nn.Softmax(dim=-1)
|
67 |
+
self.dropout = nn.Dropout(dropout)
|
68 |
+
self.final_linear = nn.Linear(model_dim, model_dim)
|
69 |
+
|
70 |
+
self.max_relative_positions = max_relative_positions
|
71 |
+
|
72 |
+
if max_relative_positions > 0:
|
73 |
+
vocab_size = max_relative_positions * 2 + 1
|
74 |
+
self.relative_positions_embeddings = nn.Embedding(
|
75 |
+
vocab_size, self.dim_per_head)
|
76 |
+
|
77 |
+
def forward(self, key, value, query, mask=None,
|
78 |
+
layer_cache=None, attn_type=None):
|
79 |
+
"""
|
80 |
+
Compute the context vector and the attention vectors.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
key (FloatTensor): set of `key_len`
|
84 |
+
key vectors ``(batch, key_len, dim)``
|
85 |
+
value (FloatTensor): set of `key_len`
|
86 |
+
value vectors ``(batch, key_len, dim)``
|
87 |
+
query (FloatTensor): set of `query_len`
|
88 |
+
query vectors ``(batch, query_len, dim)``
|
89 |
+
mask: binary mask 1/0 indicating which keys have
|
90 |
+
zero / non-zero attention ``(batch, query_len, key_len)``
|
91 |
+
Returns:
|
92 |
+
(FloatTensor, FloatTensor):
|
93 |
+
|
94 |
+
* output context vectors ``(batch, query_len, dim)``
|
95 |
+
* Attention vector in heads ``(batch, head, query_len, key_len)``.
|
96 |
+
"""
|
97 |
+
|
98 |
+
# CHECKS
|
99 |
+
# batch, k_len, d = key.size()
|
100 |
+
# batch_, k_len_, d_ = value.size()
|
101 |
+
# aeq(batch, batch_)
|
102 |
+
# aeq(k_len, k_len_)
|
103 |
+
# aeq(d, d_)
|
104 |
+
# batch_, q_len, d_ = query.size()
|
105 |
+
# aeq(batch, batch_)
|
106 |
+
# aeq(d, d_)
|
107 |
+
# aeq(self.model_dim % 8, 0)
|
108 |
+
# if mask is not None:
|
109 |
+
# batch_, q_len_, k_len_ = mask.size()
|
110 |
+
# aeq(batch_, batch)
|
111 |
+
# aeq(k_len_, k_len)
|
112 |
+
# aeq(q_len_ == q_len)
|
113 |
+
# END CHECKS
|
114 |
+
|
115 |
+
batch_size = key.size(0)
|
116 |
+
dim_per_head = self.dim_per_head
|
117 |
+
head_count = self.head_count
|
118 |
+
key_len = key.size(1)
|
119 |
+
query_len = query.size(1)
|
120 |
+
|
121 |
+
def shape(x):
|
122 |
+
"""Projection."""
|
123 |
+
return x.view(batch_size, -1, head_count, dim_per_head) \
|
124 |
+
.transpose(1, 2)
|
125 |
+
|
126 |
+
def unshape(x):
|
127 |
+
"""Compute context."""
|
128 |
+
return x.transpose(1, 2).contiguous() \
|
129 |
+
.view(batch_size, -1, head_count * dim_per_head)
|
130 |
+
|
131 |
+
# 1) Project key, value, and query.
|
132 |
+
if layer_cache is not None:
|
133 |
+
if attn_type == "self":
|
134 |
+
query, key, value = self.linear_query(query),\
|
135 |
+
self.linear_keys(query),\
|
136 |
+
self.linear_values(query)
|
137 |
+
key = shape(key)
|
138 |
+
value = shape(value)
|
139 |
+
if layer_cache["self_keys"] is not None:
|
140 |
+
key = torch.cat(
|
141 |
+
(layer_cache["self_keys"], key),
|
142 |
+
dim=2)
|
143 |
+
if layer_cache["self_values"] is not None:
|
144 |
+
value = torch.cat(
|
145 |
+
(layer_cache["self_values"], value),
|
146 |
+
dim=2)
|
147 |
+
layer_cache["self_keys"] = key
|
148 |
+
layer_cache["self_values"] = value
|
149 |
+
elif attn_type == "context":
|
150 |
+
query = self.linear_query(query)
|
151 |
+
if layer_cache["memory_keys"] is None:
|
152 |
+
key, value = self.linear_keys(key),\
|
153 |
+
self.linear_values(value)
|
154 |
+
key = shape(key)
|
155 |
+
value = shape(value)
|
156 |
+
else:
|
157 |
+
key, value = layer_cache["memory_keys"],\
|
158 |
+
layer_cache["memory_values"]
|
159 |
+
layer_cache["memory_keys"] = key
|
160 |
+
layer_cache["memory_values"] = value
|
161 |
+
else:
|
162 |
+
key = self.linear_keys(key)
|
163 |
+
value = self.linear_values(value)
|
164 |
+
query = self.linear_query(query)
|
165 |
+
key = shape(key)
|
166 |
+
value = shape(value)
|
167 |
+
|
168 |
+
if self.max_relative_positions > 0 and attn_type == "self":
|
169 |
+
key_len = key.size(2)
|
170 |
+
# 1 or key_len x key_len
|
171 |
+
relative_positions_matrix = generate_relative_positions_matrix(
|
172 |
+
key_len, self.max_relative_positions,
|
173 |
+
cache=True if layer_cache is not None else False)
|
174 |
+
# 1 or key_len x key_len x dim_per_head
|
175 |
+
relations_keys = self.relative_positions_embeddings(
|
176 |
+
relative_positions_matrix.to(key.device))
|
177 |
+
# 1 or key_len x key_len x dim_per_head
|
178 |
+
relations_values = self.relative_positions_embeddings(
|
179 |
+
relative_positions_matrix.to(key.device))
|
180 |
+
|
181 |
+
query = shape(query)
|
182 |
+
|
183 |
+
key_len = key.size(2)
|
184 |
+
query_len = query.size(2)
|
185 |
+
|
186 |
+
# 2) Calculate and scale scores.
|
187 |
+
query = query / math.sqrt(dim_per_head)
|
188 |
+
# batch x num_heads x query_len x key_len
|
189 |
+
query_key = torch.matmul(query, key.transpose(2, 3))
|
190 |
+
|
191 |
+
if self.max_relative_positions > 0 and attn_type == "self":
|
192 |
+
scores = query_key + relative_matmul(query, relations_keys, True)
|
193 |
+
else:
|
194 |
+
scores = query_key
|
195 |
+
scores = scores.float()
|
196 |
+
|
197 |
+
if mask is not None:
|
198 |
+
mask = mask.unsqueeze(1) # [B, 1, 1, T_values]
|
199 |
+
scores = scores.masked_fill(mask, -1e18)
|
200 |
+
|
201 |
+
# 3) Apply attention dropout and compute context vectors.
|
202 |
+
attn = self.softmax(scores).to(query.dtype)
|
203 |
+
drop_attn = self.dropout(attn)
|
204 |
+
|
205 |
+
context_original = torch.matmul(drop_attn, value)
|
206 |
+
|
207 |
+
if self.max_relative_positions > 0 and attn_type == "self":
|
208 |
+
context = unshape(context_original
|
209 |
+
+ relative_matmul(drop_attn,
|
210 |
+
relations_values,
|
211 |
+
False))
|
212 |
+
else:
|
213 |
+
context = unshape(context_original)
|
214 |
+
|
215 |
+
output = self.final_linear(context)
|
216 |
+
# CHECK
|
217 |
+
# batch_, q_len_, d_ = output.size()
|
218 |
+
# aeq(q_len, q_len_)
|
219 |
+
# aeq(batch, batch_)
|
220 |
+
# aeq(d, d_)
|
221 |
+
|
222 |
+
# Return multi-head attn
|
223 |
+
attns = attn \
|
224 |
+
.view(batch_size, head_count,
|
225 |
+
query_len, key_len)
|
226 |
+
|
227 |
+
return output, attns
|
228 |
+
|
229 |
+
def update_dropout(self, dropout):
|
230 |
+
self.dropout.p = dropout
|
onmt_modules/position_ffn.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Position feed-forward network from "Attention is All You Need"."""
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
|
6 |
+
class PositionwiseFeedForward(nn.Module):
|
7 |
+
""" A two-layer Feed-Forward-Network with residual layer norm.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
d_model (int): the size of input for the first-layer of the FFN.
|
11 |
+
d_ff (int): the hidden layer size of the second-layer
|
12 |
+
of the FNN.
|
13 |
+
dropout (float): dropout probability in :math:`[0, 1)`.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, d_model, d_ff, dropout=0.1):
|
17 |
+
super(PositionwiseFeedForward, self).__init__()
|
18 |
+
self.w_1 = nn.Linear(d_model, d_ff)
|
19 |
+
self.w_2 = nn.Linear(d_ff, d_model)
|
20 |
+
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
|
21 |
+
self.dropout_1 = nn.Dropout(dropout)
|
22 |
+
self.relu = nn.ReLU()
|
23 |
+
self.dropout_2 = nn.Dropout(dropout)
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
"""Layer definition.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
x: ``(batch_size, input_len, model_dim)``
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
(FloatTensor): Output ``(batch_size, input_len, model_dim)``.
|
33 |
+
"""
|
34 |
+
|
35 |
+
inter = self.dropout_1(self.relu(self.w_1(self.layer_norm(x))))
|
36 |
+
output = self.dropout_2(self.w_2(inter))
|
37 |
+
return output + x
|
38 |
+
|
39 |
+
def update_dropout(self, dropout):
|
40 |
+
self.dropout_1.p = dropout
|
41 |
+
self.dropout_2.p = dropout
|
override_decoder.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from onmt_modules.decoder_transformer import TransformerDecoder
|
2 |
+
from onmt_modules.misc import sequence_mask
|
3 |
+
|
4 |
+
|
5 |
+
class OnmtDecoder_1(TransformerDecoder):
|
6 |
+
# overide forward
|
7 |
+
# without teacher forcing for stop
|
8 |
+
def forward(self, tgt, memory_bank, step=None, **kwargs):
|
9 |
+
"""Decode, possibly stepwise."""
|
10 |
+
if step == 0:
|
11 |
+
self._init_cache(memory_bank)
|
12 |
+
|
13 |
+
if step is None:
|
14 |
+
tgt_lens = kwargs["tgt_lengths"]
|
15 |
+
else:
|
16 |
+
tgt_words = kwargs["tgt_words"]
|
17 |
+
|
18 |
+
emb = self.embeddings(tgt, step=step)
|
19 |
+
assert emb.dim() == 3 # len x batch x embedding_dim
|
20 |
+
|
21 |
+
output = emb.transpose(0, 1).contiguous()
|
22 |
+
src_memory_bank = memory_bank.transpose(0, 1).contiguous()
|
23 |
+
|
24 |
+
pad_idx = self.embeddings.word_padding_idx
|
25 |
+
src_lens = kwargs["memory_lengths"]
|
26 |
+
src_max_len = self.state["src"].shape[0]
|
27 |
+
src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1)
|
28 |
+
if step is None:
|
29 |
+
tgt_max_len = tgt_lens.max()
|
30 |
+
tgt_pad_mask = ~sequence_mask(tgt_lens, tgt_max_len).unsqueeze(1)
|
31 |
+
else:
|
32 |
+
tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1)
|
33 |
+
|
34 |
+
with_align = kwargs.pop('with_align', False)
|
35 |
+
attn_aligns = []
|
36 |
+
|
37 |
+
for i, layer in enumerate(self.transformer_layers):
|
38 |
+
layer_cache = self.state["cache"]["layer_{}".format(i)] \
|
39 |
+
if step is not None else None
|
40 |
+
output, attn, attn_align = layer(
|
41 |
+
output,
|
42 |
+
src_memory_bank,
|
43 |
+
src_pad_mask,
|
44 |
+
tgt_pad_mask,
|
45 |
+
layer_cache=layer_cache,
|
46 |
+
step=step,
|
47 |
+
with_align=with_align)
|
48 |
+
if attn_align is not None:
|
49 |
+
attn_aligns.append(attn_align)
|
50 |
+
|
51 |
+
output = self.layer_norm(output)
|
52 |
+
dec_outs = output.transpose(0, 1).contiguous()
|
53 |
+
attn = attn.transpose(0, 1).contiguous()
|
54 |
+
|
55 |
+
attns = {"std": attn}
|
56 |
+
if self._copy:
|
57 |
+
attns["copy"] = attn
|
58 |
+
if with_align:
|
59 |
+
attns["align"] = attn_aligns[self.alignment_layer] # `(B, Q, K)`
|
60 |
+
# attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg
|
61 |
+
|
62 |
+
# TODO change the way attns is returned dict => list or tuple (onnx)
|
63 |
+
return dec_outs, attns
|
prepare_train_data.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pickle
|
3 |
+
import numpy as np
|
4 |
+
import scipy.fftpack
|
5 |
+
import soundfile as sf
|
6 |
+
from utils import pySTFT
|
7 |
+
from scipy import signal
|
8 |
+
from librosa.filters import mel
|
9 |
+
from utils import butter_highpass
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from model_sea import Generator as Model
|
14 |
+
from hparams_sea import hparams
|
15 |
+
|
16 |
+
|
17 |
+
mel_basis = mel(16000, 1024, fmin=90, fmax=7600, n_mels=80).T
|
18 |
+
min_level = np.exp(-100 / 20 * np.log(10))
|
19 |
+
b, a = butter_highpass(30, 16000, order=5)
|
20 |
+
|
21 |
+
mfcc_mean, mfcc_std, dctmx = pickle.load(open('assets/mfcc_stats.pkl', 'rb'))
|
22 |
+
spk2emb = pickle.load(open('assets/spk2emb_82.pkl', 'rb'))
|
23 |
+
|
24 |
+
rootDir = "assets/vctk16-train-wav"
|
25 |
+
targetDir_sp = 'assets/vctk16-train-sp-mel'
|
26 |
+
targetDir_cep = 'assets/vctk16-train-cep-mel'
|
27 |
+
targetDir_cd = 'assets/vctk16-train-teacher'
|
28 |
+
|
29 |
+
device = 'cuda:0'
|
30 |
+
|
31 |
+
G = Model(hparams).eval().to(device)
|
32 |
+
|
33 |
+
g_checkpoint = torch.load('assets/sea.ckpt', map_location=lambda storage, loc: storage)
|
34 |
+
G.load_state_dict(g_checkpoint['model'], strict=True)
|
35 |
+
|
36 |
+
|
37 |
+
metadata = []
|
38 |
+
dirName, subdirList, _ = next(os.walk(rootDir))
|
39 |
+
|
40 |
+
for subdir in sorted(subdirList):
|
41 |
+
print(subdir)
|
42 |
+
|
43 |
+
if not os.path.exists(os.path.join(targetDir_sp, subdir)):
|
44 |
+
os.makedirs(os.path.join(targetDir_sp, subdir))
|
45 |
+
if not os.path.exists(os.path.join(targetDir_cep, subdir)):
|
46 |
+
os.makedirs(os.path.join(targetDir_cep, subdir))
|
47 |
+
if not os.path.exists(os.path.join(targetDir_cd, subdir)):
|
48 |
+
os.makedirs(os.path.join(targetDir_cd, subdir))
|
49 |
+
|
50 |
+
submeta = []
|
51 |
+
submeta.append(subdir)
|
52 |
+
submeta.append(spk2emb[subdir])
|
53 |
+
|
54 |
+
_,_, fileList = next(os.walk(os.path.join(dirName,subdir)))
|
55 |
+
|
56 |
+
for fileName in sorted(fileList):
|
57 |
+
x, fs = sf.read(os.path.join(dirName,subdir,fileName))
|
58 |
+
if x.shape[0] % 256 == 0:
|
59 |
+
x = np.concatenate((x, np.array([1e-06])), axis=0)
|
60 |
+
y = signal.filtfilt(b, a, x)
|
61 |
+
D = pySTFT(y * 0.96).T
|
62 |
+
D_mel = np.dot(D, mel_basis)
|
63 |
+
D_db = 20 * np.log10(np.maximum(min_level, D_mel))
|
64 |
+
|
65 |
+
# mel sp
|
66 |
+
S = (D_db + 80) / 100
|
67 |
+
|
68 |
+
# mel cep
|
69 |
+
cc_tmp = S.dot(dctmx)
|
70 |
+
cc_norm = (cc_tmp - mfcc_mean) / mfcc_std
|
71 |
+
S = np.clip(S, 0, 1)
|
72 |
+
|
73 |
+
# teacher code
|
74 |
+
cc_torch = torch.from_numpy(cc_norm[:,0:20].astype(np.float32)).unsqueeze(0).to(device)
|
75 |
+
with torch.no_grad():
|
76 |
+
codes = G.encode(cc_torch, torch.ones_like(cc_torch[:,:,0])).squeeze(0)
|
77 |
+
|
78 |
+
np.save(os.path.join(targetDir_cd, subdir, fileName[:-4]),
|
79 |
+
codes.cpu().numpy(), allow_pickle=False)
|
80 |
+
np.save(os.path.join(targetDir_sp, subdir, fileName[:-4]),
|
81 |
+
S.astype(np.float32), allow_pickle=False)
|
82 |
+
np.save(os.path.join(targetDir_cep, subdir, fileName[:-4]),
|
83 |
+
cc_norm.astype(np.float32), allow_pickle=False)
|
84 |
+
|
85 |
+
submeta.append(subdir+'/'+fileName[:-4]+'.npy')
|
86 |
+
|
87 |
+
metadata.append(submeta)
|
88 |
+
|
89 |
+
with open('./assets/train_vctk.meta', 'wb') as handle:
|
90 |
+
pickle.dump(metadata, handle)
|
solver_1.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import pickle
|
4 |
+
import datetime
|
5 |
+
import itertools
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from onmt_modules.misc import sequence_mask
|
11 |
+
from model_autopst import Generator_1 as Predictor
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
class Solver(object):
|
16 |
+
|
17 |
+
def __init__(self, data_loader, config, hparams):
|
18 |
+
"""Initialize configurations."""
|
19 |
+
|
20 |
+
|
21 |
+
self.data_loader = data_loader
|
22 |
+
self.hparams = hparams
|
23 |
+
self.gate_threshold = hparams.gate_threshold
|
24 |
+
|
25 |
+
self.use_cuda = torch.cuda.is_available()
|
26 |
+
self.device = torch.device('cuda:{}'.format(config.device_id) if self.use_cuda else 'cpu')
|
27 |
+
self.num_iters = config.num_iters
|
28 |
+
self.log_step = config.log_step
|
29 |
+
|
30 |
+
# Build the model
|
31 |
+
self.build_model()
|
32 |
+
|
33 |
+
|
34 |
+
def build_model(self):
|
35 |
+
|
36 |
+
self.P = Predictor(self.hparams)
|
37 |
+
|
38 |
+
self.optimizer = torch.optim.Adam(self.P.parameters(), 0.0001, [0.9, 0.999])
|
39 |
+
|
40 |
+
self.P.to(self.device)
|
41 |
+
|
42 |
+
self.BCELoss = torch.nn.BCEWithLogitsLoss().to(self.device)
|
43 |
+
|
44 |
+
|
45 |
+
def train(self):
|
46 |
+
# Set data loader
|
47 |
+
data_loader = self.data_loader
|
48 |
+
data_iter = iter(data_loader)
|
49 |
+
|
50 |
+
|
51 |
+
# Print logs in specified order
|
52 |
+
keys = ['P/loss_tx2sp', 'P/loss_stop_sp']
|
53 |
+
|
54 |
+
|
55 |
+
# Start training.
|
56 |
+
print('Start training...')
|
57 |
+
start_time = time.time()
|
58 |
+
for i in range(self.num_iters):
|
59 |
+
|
60 |
+
try:
|
61 |
+
sp_real, cep_real, cd_real, _, num_rep_sync, len_real, _, len_short_sync, spk_emb = next(data_iter)
|
62 |
+
except:
|
63 |
+
data_iter = iter(data_loader)
|
64 |
+
sp_real, cep_real, cd_real, _, num_rep_sync, len_real, _, len_short_sync, spk_emb = next(data_iter)
|
65 |
+
|
66 |
+
|
67 |
+
sp_real = sp_real.to(self.device)
|
68 |
+
cep_real = cep_real.to(self.device)
|
69 |
+
cd_real = cd_real.to(self.device)
|
70 |
+
len_real = len_real.to(self.device)
|
71 |
+
spk_emb = spk_emb.to(self.device)
|
72 |
+
num_rep_sync = num_rep_sync.to(self.device)
|
73 |
+
len_short_sync = len_short_sync.to(self.device)
|
74 |
+
|
75 |
+
|
76 |
+
# real spect masks
|
77 |
+
mask_sp_real = ~sequence_mask(len_real, sp_real.size(1))
|
78 |
+
mask_long = (~mask_sp_real).float()
|
79 |
+
|
80 |
+
len_real_mask = torch.min(len_real + 10,
|
81 |
+
torch.full_like(len_real, sp_real.size(1)))
|
82 |
+
loss_tx2sp_mask = sequence_mask(len_real_mask, sp_real.size(1)).float().unsqueeze(-1)
|
83 |
+
|
84 |
+
# text input masks
|
85 |
+
codes_mask = sequence_mask(len_short_sync, num_rep_sync.size(1)).float()
|
86 |
+
|
87 |
+
|
88 |
+
# =================================================================================== #
|
89 |
+
# 2. Train #
|
90 |
+
# =================================================================================== #
|
91 |
+
|
92 |
+
self.P = self.P.train()
|
93 |
+
|
94 |
+
|
95 |
+
sp_real_sft = torch.zeros_like(sp_real)
|
96 |
+
sp_real_sft[:, 1:, :] = sp_real[:, :-1, :]
|
97 |
+
|
98 |
+
|
99 |
+
spect_pred, stop_pred_sp = self.P(cep_real.transpose(2,1),
|
100 |
+
mask_long,
|
101 |
+
codes_mask,
|
102 |
+
num_rep_sync,
|
103 |
+
len_short_sync+1,
|
104 |
+
sp_real_sft.transpose(1,0),
|
105 |
+
len_real+1,
|
106 |
+
spk_emb)
|
107 |
+
|
108 |
+
|
109 |
+
loss_tx2sp = (F.mse_loss(spect_pred.permute(1,0,2), sp_real, reduction='none')
|
110 |
+
* loss_tx2sp_mask).sum() / loss_tx2sp_mask.sum()
|
111 |
+
|
112 |
+
loss_stop_sp = self.BCELoss(stop_pred_sp.squeeze(-1).t(), mask_sp_real.float())
|
113 |
+
|
114 |
+
loss_total = loss_tx2sp + loss_stop_sp
|
115 |
+
|
116 |
+
# Backward and optimize
|
117 |
+
self.optimizer.zero_grad()
|
118 |
+
loss_total.backward()
|
119 |
+
self.optimizer.step()
|
120 |
+
|
121 |
+
|
122 |
+
# Logging
|
123 |
+
loss = {}
|
124 |
+
loss['P/loss_tx2sp'] = loss_tx2sp.item()
|
125 |
+
loss['P/loss_stop_sp'] = loss_stop_sp.item()
|
126 |
+
|
127 |
+
|
128 |
+
# =================================================================================== #
|
129 |
+
# 4. Miscellaneous #
|
130 |
+
# =================================================================================== #
|
131 |
+
|
132 |
+
# Print out training information
|
133 |
+
if (i+1) % self.log_step == 0:
|
134 |
+
et = time.time() - start_time
|
135 |
+
et = str(datetime.timedelta(seconds=et))[:-7]
|
136 |
+
log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
|
137 |
+
for tag in keys:
|
138 |
+
log += ", {}: {:.8f}".format(tag, loss[tag])
|
139 |
+
print(log)
|
140 |
+
|
141 |
+
|
142 |
+
# Save model checkpoints.
|
143 |
+
if (i+1) % 10000 == 0:
|
144 |
+
torch.save({'model': self.P.state_dict(),
|
145 |
+
'optimizer': self.optimizer.state_dict()}, f'./assets/{i+1}-A.ckpt')
|
146 |
+
print('Saved model checkpoints into assets ...')
|
solver_2.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import pickle
|
4 |
+
import datetime
|
5 |
+
import itertools
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from onmt_modules.misc import sequence_mask
|
11 |
+
from model_autopst import Generator_2 as Predictor
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
class Solver(object):
|
16 |
+
|
17 |
+
def __init__(self, data_loader, config, hparams):
|
18 |
+
"""Initialize configurations."""
|
19 |
+
|
20 |
+
|
21 |
+
self.data_loader = data_loader
|
22 |
+
self.hparams = hparams
|
23 |
+
self.gate_threshold = hparams.gate_threshold
|
24 |
+
|
25 |
+
self.use_cuda = torch.cuda.is_available()
|
26 |
+
self.device = torch.device('cuda:{}'.format(config.device_id) if self.use_cuda else 'cpu')
|
27 |
+
self.num_iters = config.num_iters
|
28 |
+
self.log_step = config.log_step
|
29 |
+
|
30 |
+
# Build the model
|
31 |
+
self.build_model()
|
32 |
+
|
33 |
+
|
34 |
+
def build_model(self):
|
35 |
+
|
36 |
+
self.P = Predictor(self.hparams)
|
37 |
+
self.freeze_layers(self.P.encoder_cd)
|
38 |
+
|
39 |
+
self.optimizer = torch.optim.Adam(self.P.parameters(), 0.0001, [0.9, 0.999])
|
40 |
+
|
41 |
+
self.P.to(self.device)
|
42 |
+
|
43 |
+
self.BCELoss = torch.nn.BCEWithLogitsLoss().to(self.device)
|
44 |
+
|
45 |
+
|
46 |
+
checkpoint = torch.load(self.hparams.pretrained_path,
|
47 |
+
map_location=lambda storage, loc: storage)
|
48 |
+
|
49 |
+
self.P.load_state_dict(checkpoint['model'], strict=True)
|
50 |
+
print('Loaded pretrained encoder .........................................')
|
51 |
+
|
52 |
+
|
53 |
+
def freeze_layers(self, layer):
|
54 |
+
print('Fixing layers!')
|
55 |
+
for param in layer.parameters():
|
56 |
+
param.requires_grad = False
|
57 |
+
|
58 |
+
|
59 |
+
def train(self):
|
60 |
+
# Set data loader
|
61 |
+
data_loader = self.data_loader
|
62 |
+
data_iter = iter(data_loader)
|
63 |
+
|
64 |
+
|
65 |
+
# Print logs in specified order
|
66 |
+
keys = ['P/loss_tx2sp', 'P/loss_stop_sp']
|
67 |
+
|
68 |
+
|
69 |
+
# Start training.
|
70 |
+
print('Start training...')
|
71 |
+
start_time = time.time()
|
72 |
+
for i in range(self.num_iters):
|
73 |
+
|
74 |
+
try:
|
75 |
+
sp_real, cep_real, cd_real, num_rep, _, len_real, len_short, _, spk_emb = next(data_iter)
|
76 |
+
except:
|
77 |
+
data_iter = iter(data_loader)
|
78 |
+
sp_real, cep_real, cd_real, num_rep, _, len_real, len_short, _, spk_emb = next(data_iter)
|
79 |
+
|
80 |
+
|
81 |
+
sp_real = sp_real.to(self.device)
|
82 |
+
cep_real = cep_real.to(self.device)
|
83 |
+
cd_real = cd_real.to(self.device)
|
84 |
+
len_real = len_real.to(self.device)
|
85 |
+
spk_emb = spk_emb.to(self.device)
|
86 |
+
num_rep = num_rep.to(self.device)
|
87 |
+
len_short = len_short.to(self.device)
|
88 |
+
|
89 |
+
|
90 |
+
# real spect masks
|
91 |
+
mask_sp_real = ~sequence_mask(len_real, sp_real.size(1))
|
92 |
+
mask_long = (~mask_sp_real).float()
|
93 |
+
|
94 |
+
len_real_mask = torch.min(len_real + 10,
|
95 |
+
torch.full_like(len_real, sp_real.size(1)))
|
96 |
+
loss_tx2sp_mask = sequence_mask(len_real_mask, sp_real.size(1)).float().unsqueeze(-1)
|
97 |
+
|
98 |
+
# text input masks
|
99 |
+
codes_mask = sequence_mask(len_short, num_rep.size(1)).float()
|
100 |
+
|
101 |
+
|
102 |
+
# =================================================================================== #
|
103 |
+
# 2. Train #
|
104 |
+
# =================================================================================== #
|
105 |
+
|
106 |
+
self.P = self.P.train()
|
107 |
+
|
108 |
+
|
109 |
+
sp_real_sft = torch.zeros_like(sp_real)
|
110 |
+
sp_real_sft[:, 1:, :] = sp_real[:, :-1, :]
|
111 |
+
|
112 |
+
|
113 |
+
spect_pred, stop_pred_sp = self.P(cep_real.transpose(2,1),
|
114 |
+
mask_long,
|
115 |
+
codes_mask,
|
116 |
+
num_rep,
|
117 |
+
len_short+1,
|
118 |
+
sp_real_sft.transpose(1,0),
|
119 |
+
len_real+1,
|
120 |
+
spk_emb)
|
121 |
+
|
122 |
+
|
123 |
+
loss_tx2sp = (F.mse_loss(spect_pred.permute(1,0,2), sp_real, reduction='none')
|
124 |
+
* loss_tx2sp_mask).sum() / loss_tx2sp_mask.sum()
|
125 |
+
|
126 |
+
loss_stop_sp = self.BCELoss(stop_pred_sp.squeeze(-1).t(), mask_sp_real.float())
|
127 |
+
|
128 |
+
loss_total = loss_tx2sp + loss_stop_sp
|
129 |
+
|
130 |
+
# Backward and optimize
|
131 |
+
self.optimizer.zero_grad()
|
132 |
+
loss_total.backward()
|
133 |
+
self.optimizer.step()
|
134 |
+
|
135 |
+
|
136 |
+
# Logging
|
137 |
+
loss = {}
|
138 |
+
loss['P/loss_tx2sp'] = loss_tx2sp.item()
|
139 |
+
loss['P/loss_stop_sp'] = loss_stop_sp.item()
|
140 |
+
|
141 |
+
|
142 |
+
# =================================================================================== #
|
143 |
+
# 4. Miscellaneous #
|
144 |
+
# =================================================================================== #
|
145 |
+
|
146 |
+
# Print out training information
|
147 |
+
if (i+1) % self.log_step == 0:
|
148 |
+
et = time.time() - start_time
|
149 |
+
et = str(datetime.timedelta(seconds=et))[:-7]
|
150 |
+
log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
|
151 |
+
for tag in keys:
|
152 |
+
log += ", {}: {:.8f}".format(tag, loss[tag])
|
153 |
+
print(log)
|
154 |
+
|
155 |
+
|
156 |
+
# Save model checkpoints.
|
157 |
+
if (i+1) % 10000 == 0:
|
158 |
+
torch.save({'model': self.P.state_dict(),
|
159 |
+
'optimizer': self.optimizer.state_dict()}, f'./assets/{i+1}-B.ckpt')
|
160 |
+
print('Saved model checkpoints into assets ...')
|
tfcompat/__init__.py
ADDED
File without changes
|
tfcompat/hparam.py
ADDED
@@ -0,0 +1,726 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Hyperparameter values."""
|
16 |
+
from __future__ import absolute_import
|
17 |
+
from __future__ import division
|
18 |
+
from __future__ import print_function
|
19 |
+
|
20 |
+
import json
|
21 |
+
import numbers
|
22 |
+
import re
|
23 |
+
|
24 |
+
import six
|
25 |
+
|
26 |
+
## from tensorflow.contrib.training.python.training import hparam_pb2
|
27 |
+
## from tensorflow.python.framework import ops
|
28 |
+
## from tensorflow.python.util import compat
|
29 |
+
## from tensorflow.python.util import deprecation
|
30 |
+
|
31 |
+
# Define the regular expression for parsing a single clause of the input
|
32 |
+
# (delimited by commas). A legal clause looks like:
|
33 |
+
# <variable name>[<index>]? = <rhs>
|
34 |
+
# where <rhs> is either a single token or [] enclosed list of tokens.
|
35 |
+
# For example: "var[1] = a" or "x = [1,2,3]"
|
36 |
+
PARAM_RE = re.compile(r"""
|
37 |
+
(?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
|
38 |
+
(\[\s*(?P<index>\d+)\s*\])? # (optional) index: "1" or None
|
39 |
+
\s*=\s*
|
40 |
+
((?P<val>[^,\[]*) # single value: "a" or None
|
41 |
+
|
|
42 |
+
\[(?P<vals>[^\]]*)\]) # list of values: None or "1,2,3"
|
43 |
+
($|,\s*)""", re.VERBOSE)
|
44 |
+
|
45 |
+
|
46 |
+
def _parse_fail(name, var_type, value, values):
|
47 |
+
"""Helper function for raising a value error for bad assignment."""
|
48 |
+
raise ValueError(
|
49 |
+
'Could not parse hparam \'%s\' of type \'%s\' with value \'%s\' in %s' %
|
50 |
+
(name, var_type.__name__, value, values))
|
51 |
+
|
52 |
+
|
53 |
+
def _reuse_fail(name, values):
|
54 |
+
"""Helper function for raising a value error for reuse of name."""
|
55 |
+
raise ValueError('Multiple assignments to variable \'%s\' in %s' % (name,
|
56 |
+
values))
|
57 |
+
|
58 |
+
|
59 |
+
def _process_scalar_value(name, parse_fn, var_type, m_dict, values,
|
60 |
+
results_dictionary):
|
61 |
+
"""Update results_dictionary with a scalar value.
|
62 |
+
|
63 |
+
Used to update the results_dictionary to be returned by parse_values when
|
64 |
+
encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".)
|
65 |
+
|
66 |
+
Mutates results_dictionary.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
name: Name of variable in assignment ("s" or "arr").
|
70 |
+
parse_fn: Function for parsing the actual value.
|
71 |
+
var_type: Type of named variable.
|
72 |
+
m_dict: Dictionary constructed from regex parsing.
|
73 |
+
m_dict['val']: RHS value (scalar)
|
74 |
+
m_dict['index']: List index value (or None)
|
75 |
+
values: Full expression being parsed
|
76 |
+
results_dictionary: The dictionary being updated for return by the parsing
|
77 |
+
function.
|
78 |
+
|
79 |
+
Raises:
|
80 |
+
ValueError: If the name has already been used.
|
81 |
+
"""
|
82 |
+
try:
|
83 |
+
parsed_value = parse_fn(m_dict['val'])
|
84 |
+
except ValueError:
|
85 |
+
_parse_fail(name, var_type, m_dict['val'], values)
|
86 |
+
|
87 |
+
# If no index is provided
|
88 |
+
if not m_dict['index']:
|
89 |
+
if name in results_dictionary:
|
90 |
+
_reuse_fail(name, values)
|
91 |
+
results_dictionary[name] = parsed_value
|
92 |
+
else:
|
93 |
+
if name in results_dictionary:
|
94 |
+
# The name has already been used as a scalar, then it
|
95 |
+
# will be in this dictionary and map to a non-dictionary.
|
96 |
+
if not isinstance(results_dictionary.get(name), dict):
|
97 |
+
_reuse_fail(name, values)
|
98 |
+
else:
|
99 |
+
results_dictionary[name] = {}
|
100 |
+
|
101 |
+
index = int(m_dict['index'])
|
102 |
+
# Make sure the index position hasn't already been assigned a value.
|
103 |
+
if index in results_dictionary[name]:
|
104 |
+
_reuse_fail('{}[{}]'.format(name, index), values)
|
105 |
+
results_dictionary[name][index] = parsed_value
|
106 |
+
|
107 |
+
|
108 |
+
def _process_list_value(name, parse_fn, var_type, m_dict, values,
|
109 |
+
results_dictionary):
|
110 |
+
"""Update results_dictionary from a list of values.
|
111 |
+
|
112 |
+
Used to update results_dictionary to be returned by parse_values when
|
113 |
+
encountering a clause with a list RHS (e.g. "arr=[1,2,3]".)
|
114 |
+
|
115 |
+
Mutates results_dictionary.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
name: Name of variable in assignment ("arr").
|
119 |
+
parse_fn: Function for parsing individual values.
|
120 |
+
var_type: Type of named variable.
|
121 |
+
m_dict: Dictionary constructed from regex parsing.
|
122 |
+
m_dict['val']: RHS value (scalar)
|
123 |
+
values: Full expression being parsed
|
124 |
+
results_dictionary: The dictionary being updated for return by the parsing
|
125 |
+
function.
|
126 |
+
|
127 |
+
Raises:
|
128 |
+
ValueError: If the name has an index or the values cannot be parsed.
|
129 |
+
"""
|
130 |
+
if m_dict['index'] is not None:
|
131 |
+
raise ValueError('Assignment of a list to a list index.')
|
132 |
+
elements = filter(None, re.split('[ ,]', m_dict['vals']))
|
133 |
+
# Make sure the name hasn't already been assigned a value
|
134 |
+
if name in results_dictionary:
|
135 |
+
raise _reuse_fail(name, values)
|
136 |
+
try:
|
137 |
+
results_dictionary[name] = [parse_fn(e) for e in elements]
|
138 |
+
except ValueError:
|
139 |
+
_parse_fail(name, var_type, m_dict['vals'], values)
|
140 |
+
|
141 |
+
|
142 |
+
def _cast_to_type_if_compatible(name, param_type, value):
|
143 |
+
"""Cast hparam to the provided type, if compatible.
|
144 |
+
|
145 |
+
Args:
|
146 |
+
name: Name of the hparam to be cast.
|
147 |
+
param_type: The type of the hparam.
|
148 |
+
value: The value to be cast, if compatible.
|
149 |
+
|
150 |
+
Returns:
|
151 |
+
The result of casting `value` to `param_type`.
|
152 |
+
|
153 |
+
Raises:
|
154 |
+
ValueError: If the type of `value` is not compatible with param_type.
|
155 |
+
* If `param_type` is a string type, but `value` is not.
|
156 |
+
* If `param_type` is a boolean, but `value` is not, or vice versa.
|
157 |
+
* If `param_type` is an integer type, but `value` is not.
|
158 |
+
* If `param_type` is a float type, but `value` is not a numeric type.
|
159 |
+
"""
|
160 |
+
fail_msg = (
|
161 |
+
"Could not cast hparam '%s' of type '%s' from value %r" %
|
162 |
+
(name, param_type, value))
|
163 |
+
|
164 |
+
# Some callers use None, for which we can't do any casting/checking. :(
|
165 |
+
if issubclass(param_type, type(None)):
|
166 |
+
return value
|
167 |
+
|
168 |
+
# Avoid converting a non-string type to a string.
|
169 |
+
if (issubclass(param_type, (six.string_types, six.binary_type)) and
|
170 |
+
not isinstance(value, (six.string_types, six.binary_type))):
|
171 |
+
raise ValueError(fail_msg)
|
172 |
+
|
173 |
+
# Avoid converting a number or string type to a boolean or vice versa.
|
174 |
+
if issubclass(param_type, bool) != isinstance(value, bool):
|
175 |
+
raise ValueError(fail_msg)
|
176 |
+
|
177 |
+
# Avoid converting float to an integer (the reverse is fine).
|
178 |
+
if (issubclass(param_type, numbers.Integral) and
|
179 |
+
not isinstance(value, numbers.Integral)):
|
180 |
+
raise ValueError(fail_msg)
|
181 |
+
|
182 |
+
# Avoid converting a non-numeric type to a numeric type.
|
183 |
+
if (issubclass(param_type, numbers.Number) and
|
184 |
+
not isinstance(value, numbers.Number)):
|
185 |
+
raise ValueError(fail_msg)
|
186 |
+
|
187 |
+
return param_type(value)
|
188 |
+
|
189 |
+
|
190 |
+
def parse_values(values, type_map):
|
191 |
+
"""Parses hyperparameter values from a string into a python map.
|
192 |
+
|
193 |
+
`values` is a string containing comma-separated `name=value` pairs.
|
194 |
+
For each pair, the value of the hyperparameter named `name` is set to
|
195 |
+
`value`.
|
196 |
+
|
197 |
+
If a hyperparameter name appears multiple times in `values`, a ValueError
|
198 |
+
is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2').
|
199 |
+
|
200 |
+
If a hyperparameter name in both an index assignment and scalar assignment,
|
201 |
+
a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1').
|
202 |
+
|
203 |
+
The hyperparameter name may contain '.' symbols, which will result in an
|
204 |
+
attribute name that is only accessible through the getattr and setattr
|
205 |
+
functions. (And must be first explicit added through add_hparam.)
|
206 |
+
|
207 |
+
WARNING: Use of '.' in your variable names is allowed, but is not well
|
208 |
+
supported and not recommended.
|
209 |
+
|
210 |
+
The `value` in `name=value` must follows the syntax according to the
|
211 |
+
type of the parameter:
|
212 |
+
|
213 |
+
* Scalar integer: A Python-parsable integer point value. E.g.: 1,
|
214 |
+
100, -12.
|
215 |
+
* Scalar float: A Python-parsable floating point value. E.g.: 1.0,
|
216 |
+
-.54e89.
|
217 |
+
* Boolean: Either true or false.
|
218 |
+
* Scalar string: A non-empty sequence of characters, excluding comma,
|
219 |
+
spaces, and square brackets. E.g.: foo, bar_1.
|
220 |
+
* List: A comma separated list of scalar values of the parameter type
|
221 |
+
enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low].
|
222 |
+
|
223 |
+
When index assignment is used, the corresponding type_map key should be the
|
224 |
+
list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not
|
225 |
+
"arr[1]").
|
226 |
+
|
227 |
+
Args:
|
228 |
+
values: String. Comma separated list of `name=value` pairs where
|
229 |
+
'value' must follow the syntax described above.
|
230 |
+
type_map: A dictionary mapping hyperparameter names to types. Note every
|
231 |
+
parameter name in values must be a key in type_map. The values must
|
232 |
+
conform to the types indicated, where a value V is said to conform to a
|
233 |
+
type T if either V has type T, or V is a list of elements of type T.
|
234 |
+
Hence, for a multidimensional parameter 'x' taking float values,
|
235 |
+
'x=[0.1,0.2]' will parse successfully if type_map['x'] = float.
|
236 |
+
|
237 |
+
Returns:
|
238 |
+
A python map mapping each name to either:
|
239 |
+
* A scalar value.
|
240 |
+
* A list of scalar values.
|
241 |
+
* A dictionary mapping index numbers to scalar values.
|
242 |
+
(e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}")
|
243 |
+
|
244 |
+
Raises:
|
245 |
+
ValueError: If there is a problem with input.
|
246 |
+
* If `values` cannot be parsed.
|
247 |
+
* If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]').
|
248 |
+
* If the same rvalue is assigned two different values (e.g. 'a=1,a=2',
|
249 |
+
'a[1]=1,a[1]=2', or 'a=1,a=[1]')
|
250 |
+
"""
|
251 |
+
results_dictionary = {}
|
252 |
+
pos = 0
|
253 |
+
while pos < len(values):
|
254 |
+
m = PARAM_RE.match(values, pos)
|
255 |
+
if not m:
|
256 |
+
raise ValueError('Malformed hyperparameter value: %s' % values[pos:])
|
257 |
+
# Check that there is a comma between parameters and move past it.
|
258 |
+
pos = m.end()
|
259 |
+
# Parse the values.
|
260 |
+
m_dict = m.groupdict()
|
261 |
+
name = m_dict['name']
|
262 |
+
if name not in type_map:
|
263 |
+
raise ValueError('Unknown hyperparameter type for %s' % name)
|
264 |
+
type_ = type_map[name]
|
265 |
+
|
266 |
+
# Set up correct parsing function (depending on whether type_ is a bool)
|
267 |
+
if type_ == bool:
|
268 |
+
|
269 |
+
def parse_bool(value):
|
270 |
+
if value in ['true', 'True']:
|
271 |
+
return True
|
272 |
+
elif value in ['false', 'False']:
|
273 |
+
return False
|
274 |
+
else:
|
275 |
+
try:
|
276 |
+
return bool(int(value))
|
277 |
+
except ValueError:
|
278 |
+
_parse_fail(name, type_, value, values)
|
279 |
+
|
280 |
+
parse = parse_bool
|
281 |
+
else:
|
282 |
+
parse = type_
|
283 |
+
|
284 |
+
# If a singe value is provided
|
285 |
+
if m_dict['val'] is not None:
|
286 |
+
_process_scalar_value(name, parse, type_, m_dict, values,
|
287 |
+
results_dictionary)
|
288 |
+
|
289 |
+
# If the assigned value is a list:
|
290 |
+
elif m_dict['vals'] is not None:
|
291 |
+
_process_list_value(name, parse, type_, m_dict, values,
|
292 |
+
results_dictionary)
|
293 |
+
|
294 |
+
else: # Not assigned a list or value
|
295 |
+
_parse_fail(name, type_, '', values)
|
296 |
+
|
297 |
+
return results_dictionary
|
298 |
+
|
299 |
+
|
300 |
+
class HParams(object):
|
301 |
+
"""Class to hold a set of hyperparameters as name-value pairs.
|
302 |
+
|
303 |
+
A `HParams` object holds hyperparameters used to build and train a model,
|
304 |
+
such as the number of hidden units in a neural net layer or the learning rate
|
305 |
+
to use when training.
|
306 |
+
|
307 |
+
You first create a `HParams` object by specifying the names and values of the
|
308 |
+
hyperparameters.
|
309 |
+
|
310 |
+
To make them easily accessible the parameter names are added as direct
|
311 |
+
attributes of the class. A typical usage is as follows:
|
312 |
+
|
313 |
+
```python
|
314 |
+
# Create a HParams object specifying names and values of the model
|
315 |
+
# hyperparameters:
|
316 |
+
hparams = HParams(learning_rate=0.1, num_hidden_units=100)
|
317 |
+
|
318 |
+
# The hyperparameter are available as attributes of the HParams object:
|
319 |
+
hparams.learning_rate ==> 0.1
|
320 |
+
hparams.num_hidden_units ==> 100
|
321 |
+
```
|
322 |
+
|
323 |
+
Hyperparameters have type, which is inferred from the type of their value
|
324 |
+
passed at construction type. The currently supported types are: integer,
|
325 |
+
float, boolean, string, and list of integer, float, boolean, or string.
|
326 |
+
|
327 |
+
You can override hyperparameter values by calling the
|
328 |
+
[`parse()`](#HParams.parse) method, passing a string of comma separated
|
329 |
+
`name=value` pairs. This is intended to make it possible to override
|
330 |
+
any hyperparameter values from a single command-line flag to which
|
331 |
+
the user passes 'hyper-param=value' pairs. It avoids having to define
|
332 |
+
one flag for each hyperparameter.
|
333 |
+
|
334 |
+
The syntax expected for each value depends on the type of the parameter.
|
335 |
+
See `parse()` for a description of the syntax.
|
336 |
+
|
337 |
+
Example:
|
338 |
+
|
339 |
+
```python
|
340 |
+
# Define a command line flag to pass name=value pairs.
|
341 |
+
# For example using argparse:
|
342 |
+
import argparse
|
343 |
+
parser = argparse.ArgumentParser(description='Train my model.')
|
344 |
+
parser.add_argument('--hparams', type=str,
|
345 |
+
help='Comma separated list of "name=value" pairs.')
|
346 |
+
args = parser.parse_args()
|
347 |
+
...
|
348 |
+
def my_program():
|
349 |
+
# Create a HParams object specifying the names and values of the
|
350 |
+
# model hyperparameters:
|
351 |
+
hparams = tf.HParams(learning_rate=0.1, num_hidden_units=100,
|
352 |
+
activations=['relu', 'tanh'])
|
353 |
+
|
354 |
+
# Override hyperparameters values by parsing the command line
|
355 |
+
hparams.parse(args.hparams)
|
356 |
+
|
357 |
+
# If the user passed `--hparams=learning_rate=0.3` on the command line
|
358 |
+
# then 'hparams' has the following attributes:
|
359 |
+
hparams.learning_rate ==> 0.3
|
360 |
+
hparams.num_hidden_units ==> 100
|
361 |
+
hparams.activations ==> ['relu', 'tanh']
|
362 |
+
|
363 |
+
# If the hyperparameters are in json format use parse_json:
|
364 |
+
hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}')
|
365 |
+
```
|
366 |
+
"""
|
367 |
+
|
368 |
+
_HAS_DYNAMIC_ATTRIBUTES = True # Required for pytype checks.
|
369 |
+
|
370 |
+
def __init__(self, hparam_def=None, model_structure=None, **kwargs):
|
371 |
+
"""Create an instance of `HParams` from keyword arguments.
|
372 |
+
|
373 |
+
The keyword arguments specify name-values pairs for the hyperparameters.
|
374 |
+
The parameter types are inferred from the type of the values passed.
|
375 |
+
|
376 |
+
The parameter names are added as attributes of `HParams` object, so they
|
377 |
+
can be accessed directly with the dot notation `hparams._name_`.
|
378 |
+
|
379 |
+
Example:
|
380 |
+
|
381 |
+
```python
|
382 |
+
# Define 3 hyperparameters: 'learning_rate' is a float parameter,
|
383 |
+
# 'num_hidden_units' an integer parameter, and 'activation' a string
|
384 |
+
# parameter.
|
385 |
+
hparams = tf.HParams(
|
386 |
+
learning_rate=0.1, num_hidden_units=100, activation='relu')
|
387 |
+
|
388 |
+
hparams.activation ==> 'relu'
|
389 |
+
```
|
390 |
+
|
391 |
+
Note that a few names are reserved and cannot be used as hyperparameter
|
392 |
+
names. If you use one of the reserved name the constructor raises a
|
393 |
+
`ValueError`.
|
394 |
+
|
395 |
+
Args:
|
396 |
+
hparam_def: Serialized hyperparameters, encoded as a hparam_pb2.HParamDef
|
397 |
+
protocol buffer. If provided, this object is initialized by
|
398 |
+
deserializing hparam_def. Otherwise **kwargs is used.
|
399 |
+
model_structure: An instance of ModelStructure, defining the feature
|
400 |
+
crosses to be used in the Trial.
|
401 |
+
**kwargs: Key-value pairs where the key is the hyperparameter name and
|
402 |
+
the value is the value for the parameter.
|
403 |
+
|
404 |
+
Raises:
|
405 |
+
ValueError: If both `hparam_def` and initialization values are provided,
|
406 |
+
or if one of the arguments is invalid.
|
407 |
+
|
408 |
+
"""
|
409 |
+
# Register the hyperparameters and their type in _hparam_types.
|
410 |
+
# This simplifies the implementation of parse().
|
411 |
+
# _hparam_types maps the parameter name to a tuple (type, bool).
|
412 |
+
# The type value is the type of the parameter for scalar hyperparameters,
|
413 |
+
# or the type of the list elements for multidimensional hyperparameters.
|
414 |
+
# The bool value is True if the value is a list, False otherwise.
|
415 |
+
self._hparam_types = {}
|
416 |
+
self._model_structure = model_structure
|
417 |
+
if hparam_def:
|
418 |
+
## self._init_from_proto(hparam_def)
|
419 |
+
## if kwargs:
|
420 |
+
## raise ValueError('hparam_def and initialization values are '
|
421 |
+
## 'mutually exclusive')
|
422 |
+
raise ValueError('hparam_def has been disabled in this version')
|
423 |
+
else:
|
424 |
+
for name, value in six.iteritems(kwargs):
|
425 |
+
self.add_hparam(name, value)
|
426 |
+
|
427 |
+
## def _init_from_proto(self, hparam_def):
|
428 |
+
## """Creates a new HParams from `HParamDef` protocol buffer.
|
429 |
+
##
|
430 |
+
## Args:
|
431 |
+
## hparam_def: `HParamDef` protocol buffer.
|
432 |
+
## """
|
433 |
+
## assert isinstance(hparam_def, hparam_pb2.HParamDef)
|
434 |
+
## for name, value in hparam_def.hparam.items():
|
435 |
+
## kind = value.WhichOneof('kind')
|
436 |
+
## if kind.endswith('_value'):
|
437 |
+
## # Single value.
|
438 |
+
## if kind.startswith('int64'):
|
439 |
+
## # Setting attribute value to be 'int' to ensure the type is compatible
|
440 |
+
## # with both Python2 and Python3.
|
441 |
+
## self.add_hparam(name, int(getattr(value, kind)))
|
442 |
+
## elif kind.startswith('bytes'):
|
443 |
+
## # Setting attribute value to be 'str' to ensure the type is compatible
|
444 |
+
## # with both Python2 and Python3. UTF-8 encoding is assumed.
|
445 |
+
## self.add_hparam(name, compat.as_str(getattr(value, kind)))
|
446 |
+
## else:
|
447 |
+
## self.add_hparam(name, getattr(value, kind))
|
448 |
+
## else:
|
449 |
+
## # List of values.
|
450 |
+
## if kind.startswith('int64'):
|
451 |
+
## # Setting attribute value to be 'int' to ensure the type is compatible
|
452 |
+
## # with both Python2 and Python3.
|
453 |
+
## self.add_hparam(name, [int(v) for v in getattr(value, kind).value])
|
454 |
+
## elif kind.startswith('bytes'):
|
455 |
+
## # Setting attribute value to be 'str' to ensure the type is compatible
|
456 |
+
## # with both Python2 and Python3. UTF-8 encoding is assumed.
|
457 |
+
## self.add_hparam(
|
458 |
+
## name, [compat.as_str(v) for v in getattr(value, kind).value])
|
459 |
+
## else:
|
460 |
+
## self.add_hparam(name, [v for v in getattr(value, kind).value])
|
461 |
+
|
462 |
+
def add_hparam(self, name, value):
|
463 |
+
"""Adds {name, value} pair to hyperparameters.
|
464 |
+
|
465 |
+
Args:
|
466 |
+
name: Name of the hyperparameter.
|
467 |
+
value: Value of the hyperparameter. Can be one of the following types:
|
468 |
+
int, float, string, int list, float list, or string list.
|
469 |
+
|
470 |
+
Raises:
|
471 |
+
ValueError: if one of the arguments is invalid.
|
472 |
+
"""
|
473 |
+
# Keys in kwargs are unique, but 'name' could the name of a pre-existing
|
474 |
+
# attribute of this object. In that case we refuse to use it as a
|
475 |
+
# hyperparameter name.
|
476 |
+
if getattr(self, name, None) is not None:
|
477 |
+
raise ValueError('Hyperparameter name is reserved: %s' % name)
|
478 |
+
if isinstance(value, (list, tuple)):
|
479 |
+
if not value:
|
480 |
+
raise ValueError(
|
481 |
+
'Multi-valued hyperparameters cannot be empty: %s' % name)
|
482 |
+
self._hparam_types[name] = (type(value[0]), True)
|
483 |
+
else:
|
484 |
+
self._hparam_types[name] = (type(value), False)
|
485 |
+
setattr(self, name, value)
|
486 |
+
|
487 |
+
def set_hparam(self, name, value):
|
488 |
+
"""Set the value of an existing hyperparameter.
|
489 |
+
|
490 |
+
This function verifies that the type of the value matches the type of the
|
491 |
+
existing hyperparameter.
|
492 |
+
|
493 |
+
Args:
|
494 |
+
name: Name of the hyperparameter.
|
495 |
+
value: New value of the hyperparameter.
|
496 |
+
|
497 |
+
Raises:
|
498 |
+
ValueError: If there is a type mismatch.
|
499 |
+
"""
|
500 |
+
param_type, is_list = self._hparam_types[name]
|
501 |
+
if isinstance(value, list):
|
502 |
+
if not is_list:
|
503 |
+
raise ValueError(
|
504 |
+
'Must not pass a list for single-valued parameter: %s' % name)
|
505 |
+
setattr(self, name, [
|
506 |
+
_cast_to_type_if_compatible(name, param_type, v) for v in value])
|
507 |
+
else:
|
508 |
+
if is_list:
|
509 |
+
raise ValueError(
|
510 |
+
'Must pass a list for multi-valued parameter: %s.' % name)
|
511 |
+
setattr(self, name, _cast_to_type_if_compatible(name, param_type, value))
|
512 |
+
|
513 |
+
def del_hparam(self, name):
|
514 |
+
"""Removes the hyperparameter with key 'name'.
|
515 |
+
|
516 |
+
Args:
|
517 |
+
name: Name of the hyperparameter.
|
518 |
+
"""
|
519 |
+
if hasattr(self, name):
|
520 |
+
delattr(self, name)
|
521 |
+
del self._hparam_types[name]
|
522 |
+
|
523 |
+
def parse(self, values):
|
524 |
+
"""Override hyperparameter values, parsing new values from a string.
|
525 |
+
|
526 |
+
See parse_values for more detail on the allowed format for values.
|
527 |
+
|
528 |
+
Args:
|
529 |
+
values: String. Comma separated list of `name=value` pairs where
|
530 |
+
'value' must follow the syntax described above.
|
531 |
+
|
532 |
+
Returns:
|
533 |
+
The `HParams` instance.
|
534 |
+
|
535 |
+
Raises:
|
536 |
+
ValueError: If `values` cannot be parsed.
|
537 |
+
"""
|
538 |
+
type_map = dict()
|
539 |
+
for name, t in self._hparam_types.items():
|
540 |
+
param_type, _ = t
|
541 |
+
type_map[name] = param_type
|
542 |
+
|
543 |
+
values_map = parse_values(values, type_map)
|
544 |
+
return self.override_from_dict(values_map)
|
545 |
+
|
546 |
+
def override_from_dict(self, values_dict):
|
547 |
+
"""Override hyperparameter values, parsing new values from a dictionary.
|
548 |
+
|
549 |
+
Args:
|
550 |
+
values_dict: Dictionary of name:value pairs.
|
551 |
+
|
552 |
+
Returns:
|
553 |
+
The `HParams` instance.
|
554 |
+
|
555 |
+
Raises:
|
556 |
+
ValueError: If `values_dict` cannot be parsed.
|
557 |
+
"""
|
558 |
+
for name, value in values_dict.items():
|
559 |
+
self.set_hparam(name, value)
|
560 |
+
return self
|
561 |
+
|
562 |
+
## @deprecation.deprecated(None, 'Use `override_from_dict`.')
|
563 |
+
def set_from_map(self, values_map):
|
564 |
+
"""DEPRECATED. Use override_from_dict."""
|
565 |
+
return self.override_from_dict(values_dict=values_map)
|
566 |
+
|
567 |
+
def set_model_structure(self, model_structure):
|
568 |
+
self._model_structure = model_structure
|
569 |
+
|
570 |
+
def get_model_structure(self):
|
571 |
+
return self._model_structure
|
572 |
+
|
573 |
+
def to_json(self, indent=None, separators=None, sort_keys=False):
|
574 |
+
"""Serializes the hyperparameters into JSON.
|
575 |
+
|
576 |
+
Args:
|
577 |
+
indent: If a non-negative integer, JSON array elements and object members
|
578 |
+
will be pretty-printed with that indent level. An indent level of 0, or
|
579 |
+
negative, will only insert newlines. `None` (the default) selects the
|
580 |
+
most compact representation.
|
581 |
+
separators: Optional `(item_separator, key_separator)` tuple. Default is
|
582 |
+
`(', ', ': ')`.
|
583 |
+
sort_keys: If `True`, the output dictionaries will be sorted by key.
|
584 |
+
|
585 |
+
Returns:
|
586 |
+
A JSON string.
|
587 |
+
"""
|
588 |
+
return json.dumps(
|
589 |
+
self.values(),
|
590 |
+
indent=indent,
|
591 |
+
separators=separators,
|
592 |
+
sort_keys=sort_keys)
|
593 |
+
|
594 |
+
def parse_json(self, values_json):
|
595 |
+
"""Override hyperparameter values, parsing new values from a json object.
|
596 |
+
|
597 |
+
Args:
|
598 |
+
values_json: String containing a json object of name:value pairs.
|
599 |
+
|
600 |
+
Returns:
|
601 |
+
The `HParams` instance.
|
602 |
+
|
603 |
+
Raises:
|
604 |
+
ValueError: If `values_json` cannot be parsed.
|
605 |
+
"""
|
606 |
+
values_map = json.loads(values_json)
|
607 |
+
return self.override_from_dict(values_map)
|
608 |
+
|
609 |
+
def values(self):
|
610 |
+
"""Return the hyperparameter values as a Python dictionary.
|
611 |
+
|
612 |
+
Returns:
|
613 |
+
A dictionary with hyperparameter names as keys. The values are the
|
614 |
+
hyperparameter values.
|
615 |
+
"""
|
616 |
+
return {n: getattr(self, n) for n in self._hparam_types.keys()}
|
617 |
+
|
618 |
+
def get(self, key, default=None):
|
619 |
+
"""Returns the value of `key` if it exists, else `default`."""
|
620 |
+
if key in self._hparam_types:
|
621 |
+
# Ensure that default is compatible with the parameter type.
|
622 |
+
if default is not None:
|
623 |
+
param_type, is_param_list = self._hparam_types[key]
|
624 |
+
type_str = 'list<%s>' % param_type if is_param_list else str(param_type)
|
625 |
+
fail_msg = ("Hparam '%s' of type '%s' is incompatible with "
|
626 |
+
'default=%s' % (key, type_str, default))
|
627 |
+
|
628 |
+
is_default_list = isinstance(default, list)
|
629 |
+
if is_param_list != is_default_list:
|
630 |
+
raise ValueError(fail_msg)
|
631 |
+
|
632 |
+
try:
|
633 |
+
if is_default_list:
|
634 |
+
for value in default:
|
635 |
+
_cast_to_type_if_compatible(key, param_type, value)
|
636 |
+
else:
|
637 |
+
_cast_to_type_if_compatible(key, param_type, default)
|
638 |
+
except ValueError as e:
|
639 |
+
raise ValueError('%s. %s' % (fail_msg, e))
|
640 |
+
|
641 |
+
return getattr(self, key)
|
642 |
+
|
643 |
+
return default
|
644 |
+
|
645 |
+
def __contains__(self, key):
|
646 |
+
return key in self._hparam_types
|
647 |
+
|
648 |
+
def __str__(self):
|
649 |
+
return str(sorted(self.values().items()))
|
650 |
+
|
651 |
+
def __repr__(self):
|
652 |
+
return '%s(%s)' % (type(self).__name__, self.__str__())
|
653 |
+
|
654 |
+
@staticmethod
|
655 |
+
def _get_kind_name(param_type, is_list):
|
656 |
+
"""Returns the field name given parameter type and is_list.
|
657 |
+
|
658 |
+
Args:
|
659 |
+
param_type: Data type of the hparam.
|
660 |
+
is_list: Whether this is a list.
|
661 |
+
|
662 |
+
Returns:
|
663 |
+
A string representation of the field name.
|
664 |
+
|
665 |
+
Raises:
|
666 |
+
ValueError: If parameter type is not recognized.
|
667 |
+
"""
|
668 |
+
if issubclass(param_type, bool):
|
669 |
+
# This check must happen before issubclass(param_type, six.integer_types),
|
670 |
+
# since Python considers bool to be a subclass of int.
|
671 |
+
typename = 'bool'
|
672 |
+
elif issubclass(param_type, six.integer_types):
|
673 |
+
# Setting 'int' and 'long' types to be 'int64' to ensure the type is
|
674 |
+
# compatible with both Python2 and Python3.
|
675 |
+
typename = 'int64'
|
676 |
+
elif issubclass(param_type, (six.string_types, six.binary_type)):
|
677 |
+
# Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is
|
678 |
+
# compatible with both Python2 and Python3.
|
679 |
+
typename = 'bytes'
|
680 |
+
elif issubclass(param_type, float):
|
681 |
+
typename = 'float'
|
682 |
+
else:
|
683 |
+
raise ValueError('Unsupported parameter type: %s' % str(param_type))
|
684 |
+
|
685 |
+
suffix = 'list' if is_list else 'value'
|
686 |
+
return '_'.join([typename, suffix])
|
687 |
+
|
688 |
+
## def to_proto(self, export_scope=None): # pylint: disable=unused-argument
|
689 |
+
## """Converts a `HParams` object to a `HParamDef` protocol buffer.
|
690 |
+
##
|
691 |
+
## Args:
|
692 |
+
## export_scope: Optional `string`. Name scope to remove.
|
693 |
+
##
|
694 |
+
## Returns:
|
695 |
+
## A `HParamDef` protocol buffer.
|
696 |
+
## """
|
697 |
+
## hparam_proto = hparam_pb2.HParamDef()
|
698 |
+
## for name in self._hparam_types:
|
699 |
+
## # Parse the values.
|
700 |
+
## param_type, is_list = self._hparam_types.get(name, (None, None))
|
701 |
+
## kind = HParams._get_kind_name(param_type, is_list)
|
702 |
+
##
|
703 |
+
## if is_list:
|
704 |
+
## if kind.startswith('bytes'):
|
705 |
+
## v_list = [compat.as_bytes(v) for v in getattr(self, name)]
|
706 |
+
## else:
|
707 |
+
## v_list = [v for v in getattr(self, name)]
|
708 |
+
## getattr(hparam_proto.hparam[name], kind).value.extend(v_list)
|
709 |
+
## else:
|
710 |
+
## v = getattr(self, name)
|
711 |
+
## if kind.startswith('bytes'):
|
712 |
+
## v = compat.as_bytes(getattr(self, name))
|
713 |
+
## setattr(hparam_proto.hparam[name], kind, v)
|
714 |
+
##
|
715 |
+
## return hparam_proto
|
716 |
+
|
717 |
+
## @staticmethod
|
718 |
+
## def from_proto(hparam_def, import_scope=None): # pylint: disable=unused-argument
|
719 |
+
## return HParams(hparam_def=hparam_def)
|
720 |
+
|
721 |
+
|
722 |
+
## ops.register_proto_function(
|
723 |
+
## 'hparams',
|
724 |
+
## proto_type=hparam_pb2.HParamDef,
|
725 |
+
## to_proto=HParams.to_proto,
|
726 |
+
## from_proto=HParams.from_proto)
|
tfcompat/readme.md
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Source: hparam.py copied from tensorflow v1.12.0.
|
2 |
+
|
3 |
+
https://github.com/tensorflow/tensorflow/blob/v1.12.0/tensorflow/contrib/training/python/training/hparam.py
|
4 |
+
|
5 |
+
with the following:
|
6 |
+
wget https://github.com/tensorflow/tensorflow/raw/v1.12.0/tensorflow/contrib/training/python/training/hparam.py
|
7 |
+
|
8 |
+
Once all other tensorflow dependencies of these file are removed, the class keeps its goal. Functions not available due to this process are not used in this project.
|
utils.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from scipy import signal
|
5 |
+
from librosa.filters import mel
|
6 |
+
from scipy.signal import get_window
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
|
12 |
+
def butter_highpass(cutoff, fs, order=5):
|
13 |
+
nyq = 0.5 * fs
|
14 |
+
normal_cutoff = cutoff / nyq
|
15 |
+
b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
|
16 |
+
return b, a
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
def pySTFT(x, fft_length=1024, hop_length=256):
|
21 |
+
|
22 |
+
x = np.pad(x, int(fft_length//2), mode='reflect')
|
23 |
+
|
24 |
+
noverlap = fft_length - hop_length
|
25 |
+
shape = x.shape[:-1]+((x.shape[-1]-noverlap)//hop_length, fft_length)
|
26 |
+
strides = x.strides[:-1]+(hop_length*x.strides[-1], x.strides[-1])
|
27 |
+
result = np.lib.stride_tricks.as_strided(x, shape=shape,
|
28 |
+
strides=strides)
|
29 |
+
|
30 |
+
fft_window = get_window('hann', fft_length, fftbins=True)
|
31 |
+
result = np.fft.rfft(fft_window * result, n=fft_length).T
|
32 |
+
|
33 |
+
return np.abs(result)
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
class LinearNorm(torch.nn.Module):
|
38 |
+
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
|
39 |
+
super(LinearNorm, self).__init__()
|
40 |
+
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
|
41 |
+
|
42 |
+
torch.nn.init.xavier_uniform_(
|
43 |
+
self.linear_layer.weight,
|
44 |
+
gain=torch.nn.init.calculate_gain(w_init_gain))
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
return self.linear_layer(x)
|
48 |
+
|
49 |
+
|
50 |
+
class ConvNorm(torch.nn.Module):
|
51 |
+
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
|
52 |
+
padding=None, dilation=1, bias=True, w_init_gain='linear'):
|
53 |
+
super(ConvNorm, self).__init__()
|
54 |
+
if padding is None:
|
55 |
+
assert(kernel_size % 2 == 1)
|
56 |
+
padding = int(dilation * (kernel_size - 1) / 2)
|
57 |
+
|
58 |
+
self.conv = torch.nn.Conv1d(in_channels, out_channels,
|
59 |
+
kernel_size=kernel_size, stride=stride,
|
60 |
+
padding=padding, dilation=dilation,
|
61 |
+
bias=bias)
|
62 |
+
|
63 |
+
torch.nn.init.xavier_uniform_(
|
64 |
+
self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
|
65 |
+
|
66 |
+
def forward(self, signal):
|
67 |
+
conv_signal = self.conv(signal)
|
68 |
+
return conv_signal
|
69 |
+
|
70 |
+
|
71 |
+
|
72 |
+
def filter_bank_mean(num_rep, codes_mask, max_len_long):
|
73 |
+
'''
|
74 |
+
num_rep (B, L)
|
75 |
+
codes_mask (B, L)
|
76 |
+
|
77 |
+
output: filterbank (B, L, max_len_fake)
|
78 |
+
|
79 |
+
zero pad in codes must be real zero
|
80 |
+
'''
|
81 |
+
|
82 |
+
num_rep = num_rep.unsqueeze(-1) # (B, L, 1)
|
83 |
+
codes_mask = codes_mask.unsqueeze(-1) # (B, L, 1)
|
84 |
+
num_rep = num_rep * codes_mask
|
85 |
+
|
86 |
+
right_edge = num_rep.cumsum(dim=1)
|
87 |
+
left_edge = torch.zeros_like(right_edge)
|
88 |
+
left_edge[:, 1:, :] = right_edge[:, :-1, :]
|
89 |
+
right_edge = right_edge.ceil()
|
90 |
+
left_edge = left_edge.floor()
|
91 |
+
|
92 |
+
index = torch.arange(1, max_len_long+1, device=num_rep.device).view(1, 1, -1)
|
93 |
+
|
94 |
+
lower = index - left_edge
|
95 |
+
|
96 |
+
right_edge_flip = max_len_long - right_edge
|
97 |
+
|
98 |
+
upper = (index - right_edge_flip).flip(dims=(2,))
|
99 |
+
|
100 |
+
# triangular pooling
|
101 |
+
fb = F.relu(torch.min(lower, upper)).float()
|
102 |
+
|
103 |
+
# mean pooling
|
104 |
+
fb = (fb > 0).float()
|
105 |
+
|
106 |
+
norm = fb.sum(dim=-1, keepdim=True)
|
107 |
+
norm[norm==0] = 1.0
|
108 |
+
|
109 |
+
fb = fb / norm
|
110 |
+
|
111 |
+
return fb * codes_mask
|