jonathanjordan21 committed on
Commit
c021d8e
1 Parent(s): 48f5453

67809715652a92b22870c50ad30f6ff38e292006aedc75ddbdc828aa856ef68f

LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2019
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,71 @@
- ---
- title: Tts Rvc Autopst
- emoji: 💬
- colorFrom: yellow
- colorTo: purple
- sdk: gradio
- sdk_version: 4.36.1
- app_file: app.py
- pinned: false
- ---
-
- An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
+ # Global Prosody Style Transfer Without Text Transcriptions
+
+ This repository provides a PyTorch implementation of [AutoPST](https://arxiv.org/abs/2106.08519), which enables unsupervised global prosody conversion without text transcriptions.
+
+ The short video below explains the main concepts of our work. If you find this work useful and use it in your research, please consider citing our paper.
+
+ [![SpeechSplit](./assets/cover.png)](https://youtu.be/wow2DRuJ69c/)
+
+ ```
+ @InProceedings{pmlr-v139-qian21b,
+   title     = {Global Prosody Style Transfer Without Text Transcriptions},
+   author    = {Qian, Kaizhi and Zhang, Yang and Chang, Shiyu and Xiong, Jinjun and Gan, Chuang and Cox, David and Hasegawa-Johnson, Mark},
+   booktitle = {Proceedings of the 38th International Conference on Machine Learning},
+   pages     = {8650--8660},
+   year      = {2021},
+   editor    = {Meila, Marina and Zhang, Tong},
+   volume    = {139},
+   series    = {Proceedings of Machine Learning Research},
+   month     = {18--24 Jul},
+   publisher = {PMLR},
+   url       = {http://proceedings.mlr.press/v139/qian21b.html}
+ }
+ ```
+
+ ## Audio Demo
+
+ The audio demo for AutoPST can be found [here](https://auspicious3000.github.io/AutoPST-Demo/).
+
+ ## Dependencies
+ - Python 3.6
+ - NumPy
+ - SciPy
+ - PyTorch == v1.6.0
+ - librosa
+ - pysptk
+ - soundfile
+ - wavenet_vocoder (```pip install wavenet_vocoder==0.1.1```); for more information, please refer to https://github.com/r9y9/wavenet_vocoder
+
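+ As an optional sanity check (not part of the original scripts), the short sketch below simply verifies that the packages listed above import and that the pinned PyTorch version matches the list:
+
+ ```python
+ # Optional environment check -- a sketch, assuming only the dependency list above.
+ import torch
+ import librosa          # noqa: F401
+ import soundfile        # noqa: F401
+ import pysptk           # noqa: F401
+ import wavenet_vocoder  # noqa: F401
+
+ # The repo lists PyTorch == v1.6.0 under Dependencies.
+ assert torch.__version__.startswith('1.6'), torch.__version__
+ print('All listed dependencies imported successfully.')
+ ```
+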
+ ## To Run Demo
+
+ Download the [pre-trained models](https://drive.google.com/file/d/1ji3Bk6YGvXkPqFu1hLOAJp_SKw-vHGrp/view?usp=sharing) to ```assets```.
+
+ Download the same WaveNet vocoder model as in [AutoVC](https://github.com/auspicious3000/autovc) to ```assets```.
+
+ The fast and high-quality [HiFi-GAN v1](https://github.com/jik876/hifi-gan) pre-trained model is now available [here](https://drive.google.com/file/d/1n76jHs8k1sDQ3Eh5ajXwdxuY_EZw4N9N/view?usp=sharing).
+
+ Please refer to [AutoVC](https://github.com/auspicious3000/autovc) if you have any problems with the vocoder part, because the two repos share the same vocoder scripts.
+
+ Run ```demo.ipynb```.
+
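+ For reference, the first step of ```demo.ipynb``` boils down to the sketch below (paths follow the notebook; ```580000-P.ckpt``` comes from the pre-trained models downloaded above):
+
+ ```python
+ # Loading the stage-2 predictor, as in the first cell of demo.ipynb.
+ import pickle
+ import torch
+ from model_autopst import Generator_2 as Predictor
+ from hparams_autopst import hparams
+
+ device = 'cuda:0'
+ P = Predictor(hparams).eval().to(device)
+ ckpt = torch.load('./assets/580000-P.ckpt', map_location=lambda storage, loc: storage)
+ P.load_state_dict(ckpt['model'], strict=True)
+
+ # Test utterances (mel-cepstra and speaker embeddings) used by the demo.
+ dict_test = pickle.load(open('./assets/test_vctk.meta', 'rb'))
+ ```
+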
+ ## To Train
+
+ Download the [training data](https://drive.google.com/file/d/1H1dyA80qREKLHybqnYaqBRRsacIdFbnE/view?usp=sharing) to ```assets```.
+ The provided training data is very small and intended for code verification purposes only.
+ Please use the scripts to prepare your own data for training.
+
+ 1. Prepare training data: ```python prepare_train_data.py```
+
+ 2. Train 1st stage: ```python main_1.py```
+
+ 3. Train 2nd stage: ```python main_2.py```
+
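+ Before launching the training scripts, it can help to confirm that the paths expected by ```hparams_autopst.py``` exist under ```assets```. This is a small optional check, not part of the original code; the paths are simply the defaults from ```hparams_autopst.py```:
+
+ ```python
+ # Optional: verify the data-loader paths from hparams_autopst.py before training.
+ import os
+ from hparams_autopst import hparams
+
+ for path in (hparams.meta_file, hparams.feat_dir_1,
+              hparams.feat_dir_2, hparams.feat_dir_3):
+     print(path, '->', 'found' if os.path.exists(path) else 'MISSING')
+ ```
+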
70
+
71
+ This project is part of an ongoing research. We hope this repo is useful for your research. If you need any help or have any suggestions on improving the framework, please raise an issue and we will do our best to get back to you as soon as possible.
assets/cover.png ADDED
assets/mfcc_stats.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e34f5e00e6eb8591e8dcc3796a56a048c8512245ca552c83069a4c8eb3a57387
+ size 52719
assets/spk2emb_82.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ea97c16a4d1d2eca10a0481630daa8590d9d8a1e27a87e53fa4340c17440745a
+ size 32133
assets/test_vctk.meta ADDED
Binary file (200 kB). View file
 
data_loader.py ADDED
@@ -0,0 +1,228 @@
1
+ import os
2
+ import pickle
3
+ import torch
4
+ import numpy as np
5
+
6
+ from numpy.random import uniform
7
+ from torch.utils import data
8
+ from torch.utils.data.sampler import Sampler
9
+ from multiprocessing import Process, Manager
10
+
11
+
12
+
13
+ class Utterances(data.Dataset):
14
+ """Dataset class for the Utterances dataset."""
15
+
16
+ def __init__(self, hparams):
17
+ """Initialize and preprocess the Utterances dataset."""
18
+ self.meta_file = hparams.meta_file
19
+
20
+ self.feat_dir_1 = hparams.feat_dir_1
21
+ self.feat_dir_2 = hparams.feat_dir_2
22
+ self.feat_dir_3 = hparams.feat_dir_3
23
+
24
+ self.step = 4
25
+ self.split = 0
26
+
27
+ self.max_len_pad = hparams.max_len_pad
28
+
29
+ meta = pickle.load(open(self.meta_file, "rb"))
30
+
31
+ manager = Manager()
32
+ meta = manager.list(meta)
33
+ dataset = manager.list(len(meta)*[None]) # <-- can be shared between processes.
34
+ processes = []
35
+ for i in range(0, len(meta), self.step):
36
+ p = Process(target=self.load_data,
37
+ args=(meta[i:i+self.step],dataset,i))
38
+ p.start()
39
+ processes.append(p)
40
+ for p in processes:
41
+ p.join()
42
+
43
+ # very important to do dataset = list(dataset)
44
+ self.train_dataset = list(dataset)
45
+ self.num_tokens = len(self.train_dataset)
46
+
47
+ print('Finished loading the {} Utterances training dataset...'.format(self.num_tokens))
48
+
49
+
50
+ def load_data(self, submeta, dataset, idx_offset):
51
+ for k, sbmt in enumerate(submeta):
52
+ uttrs = len(sbmt)*[None]
53
+ for j, tmp in enumerate(sbmt):
54
+ if j < 2:
55
+ # fill in speaker name and embedding
56
+ uttrs[j] = tmp
57
+ else:
58
+ # fill in data
59
+ sp_tmp = np.load(os.path.join(self.feat_dir_1, tmp))
60
+ cep_tmp = np.load(os.path.join(self.feat_dir_2, tmp))[:, 0:14]
61
+ cd_tmp = np.load(os.path.join(self.feat_dir_3, tmp))
62
+
63
+ assert len(sp_tmp) == len(cep_tmp) == len(cd_tmp)
64
+
65
+ uttrs[j] = ( np.clip(sp_tmp, 0, 1), cep_tmp, cd_tmp )
66
+ dataset[idx_offset+k] = uttrs
67
+
68
+
69
+ def segment_np(self, cd_long, tau=2):
70
+
71
+ cd_norm = np.sqrt((cd_long ** 2).sum(axis=-1, keepdims=True))
72
+ G = (cd_long @ cd_long.T) / (cd_norm @ cd_norm.T)
73
+
74
+ L = G.shape[0]
75
+
76
+ num_rep = []
77
+ num_rep_sync = []
78
+
79
+ prev_boundary = 0
80
+ rate = np.random.uniform(0.8, 1.3)
81
+
82
+ for t in range(1, L+1):
83
+ if t==L:
84
+ num_rep.append(t - prev_boundary)
85
+ num_rep_sync.append(t - prev_boundary)
86
+ prev_boundary = t
87
+ if t < L:
88
+ q = np.random.uniform(rate-0.1, rate)
89
+ tmp = G[prev_boundary, max(prev_boundary-20, 0):min(prev_boundary+20, L)]
90
+ if q <= 1:
91
+ epsilon = np.quantile(tmp, q)
92
+ if np.all(G[prev_boundary, t:min(t+tau, L)] < epsilon):
93
+ num_rep.append(t - prev_boundary)
94
+ num_rep_sync.append(t - prev_boundary)
95
+ prev_boundary = t
96
+ else:
97
+ epsilon = np.quantile(tmp, 2-q)
98
+ if np.all(G[prev_boundary, t:min(t+tau, L)] < epsilon):
99
+ num_rep.append(t - prev_boundary)
100
+ else:
101
+ num_rep.extend([t-prev_boundary-0.5, 0.5])
102
+
103
+ num_rep_sync.append(t - prev_boundary)
104
+ prev_boundary = t
105
+
106
+ num_rep = np.array(num_rep)
107
+ num_rep_sync = np.array(num_rep_sync)
108
+
109
+ return num_rep, num_rep_sync
110
+
111
+
112
+ def __getitem__(self, index):
113
+ """Return M uttrs for one spkr."""
114
+ dataset = self.train_dataset
115
+
116
+ list_uttrs = dataset[index]
117
+
118
+ emb_org = list_uttrs[1]
119
+
120
+ uttr = np.random.randint(2, len(list_uttrs))
121
+ melsp, melcep, cd_real = list_uttrs[uttr]
122
+
123
+ num_rep, num_rep_sync = self.segment_np(cd_real)
124
+
125
+ return melsp, melcep, cd_real, num_rep, num_rep_sync, len(melsp), len(num_rep), len(num_rep_sync), emb_org
126
+
127
+
128
+ def __len__(self):
129
+ """Return the number of spkrs."""
130
+ return self.num_tokens
131
+
132
+
133
+
134
+ class MyCollator(object):
135
+ def __init__(self, hparams):
136
+ self.max_len_pad = hparams.max_len_pad
137
+
138
+ def __call__(self, batch):
139
+ new_batch = []
140
+
141
+ l_short_max = 0
142
+ l_short_sync_max = 0
143
+ l_real_max = 0
144
+
145
+ for token in batch:
146
+ sp_real, cep_real, cd_real, rep, rep_sync, l_real, l_short, l_short_sync, emb = token
147
+
148
+ if l_short > l_short_max:
149
+ l_short_max = l_short
150
+
151
+ if l_short_sync > l_short_sync_max:
152
+ l_short_sync_max = l_short_sync
153
+
154
+ if l_real > l_real_max:
155
+ l_real_max = l_real
156
+
157
+ sp_real_pad = np.pad(sp_real, ((0,self.max_len_pad-l_real),(0,0)), 'constant')
158
+ cep_real_pad = np.pad(cep_real, ((0,self.max_len_pad-l_real),(0,0)), 'constant')
159
+ cd_real_pad = np.pad(cd_real, ((0,self.max_len_pad-l_real),(0,0)), 'constant')
160
+
161
+ rep_pad = np.pad(rep, (0,self.max_len_pad-l_short), 'constant')
162
+ rep_sync_pad = np.pad(rep_sync, (0,self.max_len_pad-l_short_sync), 'constant')
163
+
164
+ new_batch.append( (sp_real_pad, cep_real_pad, cd_real_pad, rep_pad, rep_sync_pad, l_real, l_short, l_short_sync, emb) )
165
+
166
+ batch = new_batch
167
+
168
+ a, b, c, d, e, f, g, h, i = zip(*batch)
169
+
170
+ sp_real = torch.from_numpy(np.stack(a, axis=0))[:,:l_real_max+1,:]
171
+ cep_real = torch.from_numpy(np.stack(b, axis=0))[:,:l_real_max+1,:]
172
+ cd_real = torch.from_numpy(np.stack(c, axis=0))[:,:l_real_max+1,:]
173
+ num_rep = torch.from_numpy(np.stack(d, axis=0))[:,:l_short_max+1]
174
+ num_rep_sync = torch.from_numpy(np.stack(e, axis=0))[:,:l_short_sync_max+1]
175
+
176
+ len_real = torch.from_numpy(np.stack(f, axis=0))
177
+ len_short = torch.from_numpy(np.stack(g, axis=0))
178
+ len_short_sync = torch.from_numpy(np.stack(h, axis=0))
179
+
180
+ spk_emb = torch.from_numpy(np.stack(i, axis=0))
181
+
182
+ return sp_real, cep_real, cd_real, num_rep, num_rep_sync, len_real, len_short, len_short_sync, spk_emb
183
+
184
+
185
+
186
+ class MultiSampler(Sampler):
187
+ """Samples elements more than once in a single pass through the data.
188
+ """
189
+ def __init__(self, num_samples, n_repeats, shuffle=False):
190
+ self.num_samples = num_samples
191
+ self.n_repeats = n_repeats
192
+ self.shuffle = shuffle
193
+
194
+ def gen_sample_array(self):
195
+ self.sample_idx_array = torch.arange(self.num_samples, dtype=torch.int64).repeat(self.n_repeats)
196
+ if self.shuffle:
197
+ self.sample_idx_array = self.sample_idx_array[torch.randperm(len(self.sample_idx_array))]
198
+ return self.sample_idx_array
199
+
200
+ def __iter__(self):
201
+ return iter(self.gen_sample_array())
202
+
203
+ def __len__(self):
204
+ return len(self.sample_idx_array)
205
+
206
+
207
+
208
+ def worker_init_fn(x):
209
+ return np.random.seed((torch.initial_seed()) % (2**32))
210
+
211
+ def get_loader(hparams):
212
+ """Build and return a data loader."""
213
+
214
+ dataset = Utterances(hparams)
215
+
216
+ my_collator = MyCollator(hparams)
217
+
218
+ sampler = MultiSampler(len(dataset), hparams.samplier, shuffle=hparams.shuffle)
219
+
220
+ data_loader = data.DataLoader(dataset=dataset,
221
+ batch_size=hparams.batch_size,
222
+ sampler=sampler,
223
+ num_workers=hparams.num_workers,
224
+ drop_last=True,
225
+ pin_memory=False,
226
+ worker_init_fn=worker_init_fn,
227
+ collate_fn=my_collator)
228
+ return data_loader
demo.ipynb ADDED
@@ -0,0 +1,115 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "import pickle\n",
11
+ "import numpy as np\n",
12
+ "import torch\n",
13
+ "import torch.nn.functional as F\n",
14
+ "from collections import OrderedDict\n",
15
+ "from onmt_modules.misc import sequence_mask\n",
16
+ "from model_autopst import Generator_2 as Predictor\n",
17
+ "from hparams_autopst import hparams\n",
18
+ "\n",
19
+ "device = 'cuda:0'\n",
20
+ "\n",
21
+ "P = Predictor(hparams).eval().to(device)\n",
22
+ "\n",
23
+ "checkpoint = torch.load('./assets/580000-P.ckpt', map_location=lambda storage, loc: storage) \n",
24
+ "P.load_state_dict(checkpoint['model'], strict=True)\n",
25
+ "print('Loaded predictor .....................................................')\n",
26
+ "\n",
27
+ "dict_test = pickle.load(open('./assets/test_vctk.meta', 'rb'))"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": null,
33
+ "metadata": {},
34
+ "outputs": [],
35
+ "source": [
36
+ "spect_vc = OrderedDict()\n",
37
+ "\n",
38
+ "uttrs = [('p231', 'p270', '001'),\n",
39
+ " ('p270', 'p231', '001'),\n",
40
+ " ('p231', 'p245', '003001'),\n",
41
+ " ('p245', 'p231', '003001'),\n",
42
+ " ('p239', 'p270', '024002'),\n",
43
+ " ('p270', 'p239', '024002')]\n",
44
+ "\n",
45
+ "\n",
46
+ "for uttr in uttrs:\n",
47
+ " \n",
48
+ " cep_real, spk_emb = dict_test[uttr[0]][uttr[2]]\n",
49
+ " cep_real_A = torch.from_numpy(cep_real).unsqueeze(0).to(device)\n",
50
+ " len_real_A = torch.tensor(cep_real_A.size(1)).unsqueeze(0).to(device)\n",
51
+ " real_mask_A = sequence_mask(len_real_A, cep_real_A.size(1)).float()\n",
52
+ " \n",
53
+ " _, spk_emb = dict_test[uttr[1]][uttr[2]]\n",
54
+ " spk_emb_B = torch.from_numpy(spk_emb).unsqueeze(0).to(device)\n",
55
+ " \n",
56
+ " with torch.no_grad():\n",
57
+ " spect_output, len_spect = P.infer_onmt(cep_real_A.transpose(2,1)[:,:14,:],\n",
58
+ " real_mask_A,\n",
59
+ " len_real_A,\n",
60
+ " spk_emb_B)\n",
61
+ " \n",
62
+ " uttr_tgt = spect_output[:len_spect[0],0,:].cpu().numpy()\n",
63
+ " \n",
64
+ " spect_vc[f'{uttr[0]}_{uttr[1]}_{uttr[2]}'] = uttr_tgt"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "execution_count": null,
70
+ "metadata": {},
71
+ "outputs": [],
72
+ "source": [
73
+ "# spectrogram to waveform\n",
74
+ "# Feel free to use other vocoders\n",
75
+ "# This cell requires some preparation to work, please see the corresponding part in AutoVC\n",
76
+ "import torch\n",
77
+ "import librosa\n",
78
+ "import pickle\n",
79
+ "import os\n",
80
+ "from synthesis import build_model\n",
81
+ "from synthesis import wavegen\n",
82
+ "\n",
83
+ "model = build_model().to(device)\n",
84
+ "checkpoint = torch.load(\"./assets/checkpoint_step001000000_ema.pth\")\n",
85
+ "model.load_state_dict(checkpoint[\"state_dict\"])\n",
86
+ "\n",
87
+ "for name, sp in spect_vc.items():\n",
88
+ " print(name)\n",
89
+ " waveform = wavegen(model, c=sp) \n",
90
+ " librosa.output.write_wav('./assets/'+name+'.wav', waveform, sr=16000)"
91
+ ]
92
+ }
93
+ ],
94
+ "metadata": {
95
+ "kernelspec": {
96
+ "display_name": "Python 3",
97
+ "language": "python",
98
+ "name": "python3"
99
+ },
100
+ "language_info": {
101
+ "codemirror_mode": {
102
+ "name": "ipython",
103
+ "version": 3
104
+ },
105
+ "file_extension": ".py",
106
+ "mimetype": "text/x-python",
107
+ "name": "python",
108
+ "nbconvert_exporter": "python",
109
+ "pygments_lexer": "ipython3",
110
+ "version": "3.7.5"
111
+ }
112
+ },
113
+ "nbformat": 4,
114
+ "nbformat_minor": 4
115
+ }
fast_decoders.py ADDED
@@ -0,0 +1,75 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from onmt_modules.misc import sequence_mask
6
+
7
+
8
+ class DecodeFunc_Sp(object):
9
+ """
10
+ Decoding functions
11
+ """
12
+ def __init__(self, hparams, type_out):
13
+
14
+ if type_out == 'Sp':
15
+ self.dim_freq = hparams.dim_freq
16
+ self.max_decoder_steps = hparams.dec_steps_sp
17
+ elif type_out == 'Tx':
18
+ self.dim_freq = hparams.dim_code
19
+ self.max_decoder_steps = hparams.dec_steps_tx
20
+ else:
21
+ raise ValueError
22
+
23
+ self.gate_threshold = hparams.gate_threshold
24
+ self.type_out = type_out
25
+
26
+ def __call__(self, tgt, memory_bank, memory_lengths, decoder, postnet):
27
+
28
+ dec_outs, attns = decoder(tgt, memory_bank, step=None,
29
+ memory_lengths=memory_lengths)
30
+ spect_gate = postnet(dec_outs)
31
+ spect, gate = spect_gate[:, :, 1:], spect_gate[:, :, :1]
32
+
33
+ return spect, gate
34
+
35
+
36
+ def infer(self, tgt_real, memory_bank, memory_lengths, decoder, postnet):
37
+ B = memory_bank.size(1)
38
+ device = memory_bank.device
39
+
40
+ spect_outputs = torch.zeros((self.max_decoder_steps, B, self.dim_freq),
41
+ dtype=torch.float, device=device)
42
+ gate_outputs = torch.zeros((self.max_decoder_steps, B, 1),
43
+ dtype=torch.float, device=device)
44
+ tgt_words = torch.zeros([B, 1],
45
+ dtype=torch.float, device=device)
46
+
47
+ current_pred = torch.zeros([1, B, self.dim_freq],
48
+ dtype=torch.float, device=device)
49
+
50
+ for t in range(self.max_decoder_steps):
51
+
52
+ dec_outs, _ = decoder(current_pred,
53
+ memory_bank, t,
54
+ memory_lengths=memory_lengths,
55
+ tgt_words=tgt_words)
56
+ spect_gate = postnet(dec_outs)
57
+
58
+ spect, gate = spect_gate[:, :, 1:], spect_gate[:, :, :1]
59
+ spect_outputs[t:t+1] = spect
60
+ gate_outputs[t:t+1] = gate
61
+
62
+ stop = (torch.sigmoid(gate) - self.gate_threshold + 0.5).round()
63
+ current_pred = spect.data
64
+ tgt_words = stop.squeeze(-1).t()
65
+
66
+ if t == self.max_decoder_steps - 1:
67
+ print(f"Warning! {self.type_out} reached max decoder steps")
68
+
69
+ if (stop == 1).all():
70
+ break
71
+
72
+ stop_quant = (torch.sigmoid(gate_outputs.data) - self.gate_threshold + 0.5).round().squeeze(-1)
73
+ len_spect = (stop_quant.cumsum(dim=0)==0).sum(dim=0)
74
+
75
+ return spect_outputs, len_spect, gate_outputs
hparams_autopst.py ADDED
@@ -0,0 +1,62 @@
+ from tfcompat.hparam import HParams
+
+ # NOTE: If you want full control over the model architecture, please take a look
+ # at the code and change whatever you want. Some hyperparameters are hardcoded.
+
+ # Default hyperparameters:
+ hparams = HParams(
+
+     # sea params
+     dim_neck_sea = 4,
+     dim_freq_sea = 14,
+     dim_enc_sea = 512,
+
+     # autopst params
+     dim_freq = 80,
+     dim_code = 4,
+     dim_spk = 82,
+     dim_sty = 128,
+     gate_threshold = 0.48,
+     dec_steps_tx = 640,
+     dec_steps_sp = 806,
+     chs_grp = 16,
+
+     # onmt params
+     enc_layers = 4,
+     enc_rnn_size = 256,
+     dec_layers = 4,
+     dec_rnn_size = 256,
+     transformer_ff = 1024,
+     heads = 8,
+     dropout = 0.1,
+     attention_dropout = 0.1,
+     max_relative_positions = 0,
+     copy_attn = False,
+     self_attn_type = "scaled-dot",
+     aan_useffn = False,
+     full_context_alignment = False,
+     alignment_layer = 0,
+     alignment_heads = 0,
+
+     # pretrained model
+     pretrained_path = './assets/xxx.ckpt',
+
+     # data loader
+     meta_file = './assets/train_vctk.meta',
+     feat_dir_1 = './assets/vctk16-train-sp-mel',
+     feat_dir_2 = './assets/vctk16-train-cep-mel',
+     feat_dir_3 = './assets/vctk16-train-teacher',
+     batch_size = 4,
+     shuffle = True,
+     num_workers = 0,
+     samplier = 2,
+     max_len_pad = 2048,
+     sampling_params = (0.8, 1.3, 0.1),
+
+ )
+
+
+ def hparams_debug_string():
+     values = hparams.values()
+     hp = [' %s: %s' % (name, values[name]) for name in values]
+     return 'Hyperparameters:\n' + '\n'.join(hp)
hparams_sea.py ADDED
@@ -0,0 +1,21 @@
+ from tfcompat.hparam import HParams
+
+ # NOTE: If you want full control over the model architecture, please take a look
+ # at the code and change whatever you want. Some hyperparameters are hardcoded.
+
+ # Default hyperparameters:
+ hparams = HParams(
+     dim_neck_sea = 8,
+     dim_freq_sea = 20,
+     dim_spk = 82,
+     dim_enc_sea = 512,
+     chs_grp = 16,
+     dim_freq_sp = 80,
+
+ )
+
+
+ def hparams_debug_string():
+     values = hparams.values()
+     hp = [' %s: %s' % (name, values[name]) for name in values]
+     return 'Hyperparameters:\n' + '\n'.join(hp)
main_1.py ADDED
@@ -0,0 +1,41 @@
+ import os
+ import argparse
+ import torch
+
+ from solver_1 import Solver
+ from data_loader import get_loader
+ from hparams_autopst import hparams, hparams_debug_string
+
+
+
+ def str2bool(v):
+     return v.lower() in ('true')
+
+ def main(config):
+
+     # Data loader
+     data_loader = get_loader(hparams)
+
+     # Solver for training
+     solver = Solver(data_loader, config, hparams)
+
+     solver.train()
+
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+
+     # Training configuration.
+     parser.add_argument('--num_iters', type=int, default=1000000)
+
+     # Miscellaneous.
+     parser.add_argument('--device_id', type=int, default=0)
+
+     # Step size.
+     parser.add_argument('--log_step', type=int, default=10)
+
+     config = parser.parse_args()
+     print(config)
+     print(hparams_debug_string())
+     main(config)
main_2.py ADDED
@@ -0,0 +1,41 @@
+ import os
+ import argparse
+ import torch
+
+ from solver_2 import Solver
+ from data_loader import get_loader
+ from hparams_autopst import hparams, hparams_debug_string
+
+
+
+ def str2bool(v):
+     return v.lower() in ('true')
+
+ def main(config):
+
+     # Data loader
+     data_loader = get_loader(hparams)
+
+     # Solver for training
+     solver = Solver(data_loader, config, hparams)
+
+     solver.train()
+
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+
+     # Training configuration.
+     parser.add_argument('--num_iters', type=int, default=1000000)
+
+     # Miscellaneous.
+     parser.add_argument('--device_id', type=int, default=0)
+
+     # Step size.
+     parser.add_argument('--log_step', type=int, default=10)
+
+     config = parser.parse_args()
+     print(config)
+     print(hparams_debug_string())
+     main(config)
model_autopst.py ADDED
@@ -0,0 +1,258 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+ from utils import filter_bank_mean
7
+
8
+ from fast_decoders import DecodeFunc_Sp
9
+
10
+ from model_sea import Encoder_2 as Encoder_Code_2
11
+
12
+ from override_decoder import OnmtDecoder_1 as OnmtDecoder
13
+
14
+ from onmt_modules.misc import sequence_mask
15
+ from onmt_modules.embeddings import PositionalEncoding
16
+ from onmt_modules.encoder_transformer import TransformerEncoder as OnmtEncoder
17
+
18
+
19
+
20
+ class Prenet(nn.Module):
21
+ def __init__(self, dim_input, dim_output, dropout=0.1):
22
+ super().__init__()
23
+
24
+ mlp = nn.Linear(dim_input, dim_output, bias=True)
25
+ pe = PositionalEncoding(dropout, dim_output, 1600)
26
+
27
+ self.make_prenet = nn.Sequential()
28
+ self.make_prenet.add_module('mlp', mlp)
29
+ self.make_prenet.add_module('pe', pe)
30
+
31
+ self.word_padding_idx = 1
32
+
33
+ def forward(self, source, step=None):
34
+
35
+ for i, module in enumerate(self.make_prenet._modules.values()):
36
+ if i == len(self.make_prenet._modules.values()) - 1:
37
+ source = module(source, step=step)
38
+ else:
39
+ source = module(source)
40
+
41
+ return source
42
+
43
+
44
+
45
+ class Decoder_Sp(nn.Module):
46
+ """
47
+ Speech Decoder
48
+ """
49
+ def __init__(self, hparams):
50
+ super().__init__()
51
+
52
+ self.dim_freq = hparams.dim_freq
53
+ self.max_decoder_steps = hparams.dec_steps_sp
54
+ self.gate_threshold = hparams.gate_threshold
55
+
56
+ prenet = Prenet(hparams.dim_freq, hparams.dec_rnn_size)
57
+ self.decoder = OnmtDecoder.from_opt(hparams, prenet)
58
+
59
+ self.postnet = nn.Linear(hparams.dec_rnn_size,
60
+ hparams.dim_freq+1, bias=True)
61
+
62
+ def forward(self, tgt, tgt_lengths, memory_bank, memory_lengths):
63
+
64
+ dec_outs, attns = self.decoder(tgt, memory_bank, step=None,
65
+ memory_lengths=memory_lengths,
66
+ tgt_lengths=tgt_lengths)
67
+ spect_gate = self.postnet(dec_outs)
68
+ spect, gate = spect_gate[:, :, 1:], spect_gate[:, :, :1]
69
+
70
+ return spect, gate
71
+
72
+
73
+
74
+ class Encoder_Tx_Spk(nn.Module):
75
+ """
76
+ Text Encoder
77
+ """
78
+ def __init__(self, hparams):
79
+ super().__init__()
80
+
81
+ prenet = Prenet(hparams.dim_code+hparams.dim_spk,
82
+ hparams.enc_rnn_size)
83
+ self.encoder = OnmtEncoder.from_opt(hparams, prenet)
84
+
85
+ def forward(self, src, src_lengths, spk_emb):
86
+
87
+ spk_emb = spk_emb.unsqueeze(0).expand(src.size(0),-1,-1)
88
+ src_spk = torch.cat((src, spk_emb), dim=-1)
89
+ enc_states, memory_bank, src_lengths = self.encoder(src_spk, src_lengths)
90
+
91
+ return enc_states, memory_bank, src_lengths
92
+
93
+
94
+
95
+ class Decoder_Tx(nn.Module):
96
+ """
97
+ Text Decoder with stop
98
+ and num_rep prediction
99
+ """
100
+ def __init__(self, hparams):
101
+ super().__init__()
102
+
103
+ self.dim_code = hparams.dim_code
104
+ self.max_decoder_steps = hparams.dec_steps_tx
105
+ self.gate_threshold = hparams.gate_threshold
106
+ self.dim_rep = hparams.dim_rep
107
+
108
+ prenet = Prenet(hparams.dim_code, hparams.dec_rnn_size)
109
+ self.decoder = OnmtDecoder.from_opt(hparams, prenet)
110
+
111
+ self.postnet_1 = nn.Linear(hparams.dec_rnn_size,
112
+ hparams.dim_code+1, bias=True)
113
+
114
+ self.postnet_2 = nn.Linear(hparams.dec_rnn_size,
115
+ self.dim_rep, bias=True)
116
+
117
+ def forward(self, tgt, tgt_lengths, memory_bank, memory_lengths):
118
+
119
+ dec_outs, attns = self.decoder(tgt, memory_bank, step=None,
120
+ memory_lengths=memory_lengths,
121
+ tgt_lengths=tgt_lengths)
122
+ gate_text = self.postnet_1(dec_outs)
123
+ rep = self.postnet_2(dec_outs)
124
+ gate, text = gate_text[:, :, :1], gate_text[:, :, 1:]
125
+
126
+ return text, gate, rep
127
+
128
+
129
+
130
+ class Generator_1(nn.Module):
131
+ '''
132
+ sync stage 1
133
+ '''
134
+ def __init__(self, hparams):
135
+ super().__init__()
136
+
137
+ self.encoder_cd = Encoder_Code_2(hparams)
138
+ self.encoder_tx = Encoder_Tx_Spk(hparams)
139
+ self.decoder_sp = Decoder_Sp(hparams)
140
+ self.encoder_spk = nn.Linear(hparams.dim_spk,
141
+ hparams.enc_rnn_size, bias=True)
142
+ self.fast_dec_sp = DecodeFunc_Sp(hparams, 'Sp')
143
+
144
+
145
+ def pad_sequences_rnn(self, cd_short, num_rep, len_long):
146
+ B, L, C = cd_short.size()
147
+ out_tensor = torch.zeros((B, len_long.max(), C), device=cd_short.device)
148
+ '''
149
+ len_long = len_spect + 1
150
+ '''
151
+ for i in range(B):
152
+ code_sync = cd_short[i].repeat_interleave(num_rep[i], dim=0)
153
+ out_tensor[i, :len_long[i]-1, :] = code_sync
154
+
155
+ return out_tensor
156
+
157
+
158
+ def forward(self, cep_in, mask_long, codes_mask, num_rep, len_short,
159
+ tgt_spect, len_spect,
160
+ spk_emb):
161
+
162
+ cd_long = self.encoder_cd(cep_in, mask_long)
163
+ fb = filter_bank_mean(num_rep, codes_mask, cd_long.size(1))
164
+
165
+ cd_short = torch.bmm(fb.detach(), cd_long)
166
+
167
+ cd_short_sync = self.pad_sequences_rnn(cd_short, num_rep, len_spect)
168
+
169
+ spk_emb_1 = self.encoder_spk(spk_emb)
170
+
171
+ # text to speech
172
+ _, memory_tx, _ = self.encoder_tx(cd_short_sync.transpose(1,0), len_spect,
173
+ spk_emb)
174
+ memory_tx_spk = torch.cat((spk_emb_1.unsqueeze(0), memory_tx), dim=0)
175
+ self.decoder_sp.decoder.init_state(memory_tx_spk, None, None)
176
+ spect_out, gate_sp_out \
177
+ = self.decoder_sp(tgt_spect, len_spect, memory_tx_spk, len_spect+1)
178
+
179
+ return spect_out, gate_sp_out
180
+
181
+
182
+ def infer_onmt(self, cep_in, mask_long,
183
+ len_spect,
184
+ spk_emb):
185
+
186
+ cd_long = self.encoder_cd(cep_in, mask_long)
187
+
188
+ spk_emb_1 = self.encoder_spk(spk_emb)
189
+
190
+ # text to speech
191
+ _, memory_tx, _ = self.encoder_tx(cd_long.transpose(1,0), len_spect,
192
+ spk_emb)
193
+ memory_tx_spk = torch.cat((spk_emb_1.unsqueeze(0), memory_tx), dim=0)
194
+ self.decoder_sp.decoder.init_state(memory_tx_spk, None, None)
195
+ spect_output, len_spect_out, stop_sp_output \
196
+ = self.fast_dec_sp.infer(None, memory_tx_spk, len_spect+1,
197
+ self.decoder_sp.decoder,
198
+ self.decoder_sp.postnet)
199
+
200
+ return spect_output, len_spect_out
201
+
202
+
203
+
204
+ class Generator_2(nn.Module):
205
+ '''
206
+ async stage 2
207
+ '''
208
+ def __init__(self, hparams):
209
+ super().__init__()
210
+
211
+ self.encoder_cd = Encoder_Code_2(hparams)
212
+ self.encoder_tx = Encoder_Tx_Spk(hparams)
213
+ self.decoder_sp = Decoder_Sp(hparams)
214
+ self.encoder_spk = nn.Linear(hparams.dim_spk,
215
+ hparams.enc_rnn_size, bias=True)
216
+ self.fast_dec_sp = DecodeFunc_Sp(hparams, 'Sp')
217
+
218
+
219
+ def forward(self, cep_in, mask_long, codes_mask, num_rep, len_short,
220
+ tgt_spect, len_spect,
221
+ spk_emb):
222
+
223
+ cd_long = self.encoder_cd(cep_in, mask_long)
224
+ fb = filter_bank_mean(num_rep, codes_mask, cd_long.size(1))
225
+
226
+ cd_short = torch.bmm(fb.detach(), cd_long.detach())
227
+
228
+ spk_emb_1 = self.encoder_spk(spk_emb)
229
+
230
+ # text to speech
231
+ _, memory_tx, _ = self.encoder_tx(cd_short.transpose(1,0), len_short,
232
+ spk_emb)
233
+ memory_tx_spk = torch.cat((spk_emb_1.unsqueeze(0), memory_tx), dim=0)
234
+ self.decoder_sp.decoder.init_state(memory_tx_spk, None, None)
235
+ spect_out, gate_sp_out \
236
+ = self.decoder_sp(tgt_spect, len_spect, memory_tx_spk, len_short+1)
237
+
238
+ return spect_out, gate_sp_out
239
+
240
+
241
+ def infer_onmt(self, cep_in, mask_long, len_spect,
242
+ spk_emb):
243
+
244
+ cd_long = self.encoder_cd(cep_in, mask_long)
245
+
246
+ spk_emb_1 = self.encoder_spk(spk_emb)
247
+
248
+ # text to speech
249
+ _, memory_tx, _ = self.encoder_tx(cd_long.transpose(1,0), len_spect,
250
+ spk_emb)
251
+ memory_tx_spk = torch.cat((spk_emb_1.unsqueeze(0), memory_tx), dim=0)
252
+ self.decoder_sp.decoder.init_state(memory_tx_spk, None, None)
253
+ spect_output, len_spect_out, stop_sp_output \
254
+ = self.fast_dec_sp.infer(None, memory_tx_spk, len_spect+1,
255
+ self.decoder_sp.decoder,
256
+ self.decoder_sp.postnet)
257
+
258
+ return spect_output, len_spect_out
model_sea.py ADDED
@@ -0,0 +1,242 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+ from utils import ConvNorm, LinearNorm
7
+ from torch.nn.parameter import Parameter
8
+
9
+
10
+
11
+ class GroupNorm_Mask(nn.Module):
12
+ def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
13
+ super().__init__()
14
+
15
+ self.num_groups = num_groups
16
+ self.num_channels = num_channels
17
+ self.eps = eps
18
+ self.affine = affine
19
+ if self.affine:
20
+ self.weight = Parameter(torch.Tensor(num_channels))
21
+ self.bias = Parameter(torch.Tensor(num_channels))
22
+ else:
23
+ self.register_parameter('weight', None)
24
+ self.register_parameter('bias', None)
25
+ self.reset_parameters()
26
+
27
+ def reset_parameters(self):
28
+ if self.affine:
29
+ nn.init.ones_(self.weight)
30
+ nn.init.zeros_(self.bias)
31
+
32
+ def forward(self, x, mask):
33
+ B, C, L = x.size()
34
+ assert C % self.num_groups == 0
35
+
36
+ x = x.view(B, self.num_groups, C//self.num_groups, L)
37
+ mask = mask.view(B, 1, 1, L)
38
+ x = x * mask
39
+
40
+ mean = x.mean(dim=2, keepdim=True).sum(dim=3, keepdim=True) / mask.sum(dim=3, keepdim=True)
41
+ var = (((x - mean)**2)*mask).mean(dim=2, keepdim=True).sum(dim=3, keepdim=True) / mask.sum(dim=3, keepdim=True)
42
+
43
+ x = (x - mean) / (var + self.eps).sqrt()
44
+
45
+ x = x.view(B, C, L)
46
+
47
+ return x * self.weight.view(1,-1,1) + self.bias.view(1,-1,1)
48
+
49
+
50
+
51
+ class M43_Sequential(nn.Sequential):
52
+ def forward(self, inputs, mask):
53
+ inputs = self._modules['0'](inputs)
54
+ inputs = self._modules['1'](inputs, mask)
55
+ return inputs
56
+
57
+
58
+
59
+ class Encoder(nn.Module):
60
+ """Encoder module:
61
+ """
62
+ def __init__(self, hparams):
63
+ super(Encoder, self).__init__()
64
+
65
+ self.dim_freq = hparams.dim_freq_sea
66
+ self.dim_enc = hparams.dim_enc_sea
67
+ self.chs_grp = hparams.chs_grp
68
+ self.dim_neck = hparams.dim_neck_sea
69
+
70
+ convolutions = []
71
+ for i in range(5):
72
+ conv_layer = M43_Sequential(
73
+ ConvNorm(self.dim_freq if i==0 else self.dim_enc,
74
+ self.dim_enc,
75
+ kernel_size=1, stride=1,
76
+ padding=0,
77
+ dilation=1, w_init_gain='relu'),
78
+ GroupNorm_Mask(self.dim_enc//self.chs_grp, self.dim_enc))
79
+ convolutions.append(conv_layer)
80
+
81
+ conv_layer = M43_Sequential(
82
+ ConvNorm(self.dim_enc,
83
+ 128,
84
+ kernel_size=1, stride=1,
85
+ padding=0,
86
+ dilation=1, w_init_gain='relu'),
87
+ GroupNorm_Mask(128//self.chs_grp, 128))
88
+ convolutions.append(conv_layer)
89
+
90
+ conv_layer = M43_Sequential(
91
+ ConvNorm(128,
92
+ 32,
93
+ kernel_size=1, stride=1,
94
+ padding=0,
95
+ dilation=1, w_init_gain='relu'),
96
+ GroupNorm_Mask(32//self.chs_grp, 32))
97
+ convolutions.append(conv_layer)
98
+
99
+ conv_layer = M43_Sequential(
100
+ ConvNorm(32,
101
+ self.dim_neck,
102
+ kernel_size=1, stride=1,
103
+ padding=0,
104
+ dilation=1, w_init_gain='relu'),
105
+ GroupNorm_Mask(1, self.dim_neck))
106
+ convolutions.append(conv_layer)
107
+
108
+ self.convolutions = nn.ModuleList(convolutions)
109
+
110
+
111
+ def forward(self, x, mask):
112
+
113
+ for conv in self.convolutions:
114
+ x = F.relu(conv(x, mask))
115
+
116
+ codes = x.permute(0, 2, 1) * mask.unsqueeze(-1)
117
+
118
+ return codes
119
+
120
+
121
+
122
+ class Decoder(nn.Module):
123
+ """Decoder module:
124
+ """
125
+ def __init__(self, hparams):
126
+ super(Decoder, self).__init__()
127
+ self.dim_enc = hparams.dim_enc_sea
128
+ self.dim_emb = hparams.dim_spk
129
+ self.dim_freq = hparams.dim_freq_sp
130
+ self.dim_neck = hparams.dim_neck_sea
131
+
132
+ self.lstm = nn.LSTM(self.dim_neck+self.dim_emb,
133
+ 1024, 3, batch_first=True)
134
+
135
+ self.linear_projection = LinearNorm(1024, self.dim_freq)
136
+
137
+ def forward(self, x):
138
+
139
+ outputs = self.lstm(x)[0]
140
+
141
+ decoder_output = self.linear_projection(outputs)
142
+
143
+ return decoder_output
144
+
145
+
146
+
147
+
148
+ class Generator(nn.Module):
149
+ """Generator network."""
150
+ def __init__(self, hparams):
151
+ super(Generator, self).__init__()
152
+
153
+ self.encoder = Encoder(hparams)
154
+ self.decoder = Decoder(hparams)
155
+
156
+ def forward(self, x, c_trg):
157
+
158
+ x = x.transpose(2,1)
159
+ codes = self.encoder(x)
160
+
161
+ encoder_outputs = torch.cat((codes,
162
+ c_trg.unsqueeze(1).expand(-1,x.size(-1),-1)), dim=-1)
163
+ mel_outputs = self.decoder(encoder_outputs)
164
+
165
+ return mel_outputs
166
+
167
+ def encode(self, x, mask):
168
+ x = x.transpose(2,1)
169
+ codes = self.encoder(x, mask)
170
+ return codes
171
+
172
+ def decode(self, codes, c_trg):
173
+ encoder_outputs = torch.cat((codes,
174
+ c_trg.unsqueeze(1).expand(-1,codes.size(1),-1)), dim=-1)
175
+ mel_outputs = self.decoder(encoder_outputs)
176
+
177
+ return mel_outputs
178
+
179
+
180
+
181
+ class Encoder_2(nn.Module):
182
+ """Encoder module:
183
+ """
184
+ def __init__(self, hparams):
185
+ super().__init__()
186
+
187
+ self.dim_freq = hparams.dim_freq_sea
188
+ self.dim_enc = hparams.dim_enc_sea
189
+ self.chs_grp = hparams.chs_grp
190
+ self.dim_neck = hparams.dim_neck_sea
191
+
192
+ convolutions = []
193
+ for i in range(5):
194
+ conv_layer = M43_Sequential(
195
+ ConvNorm(self.dim_freq if i==0 else self.dim_enc,
196
+ self.dim_enc,
197
+ kernel_size=5, stride=1,
198
+ padding=2,
199
+ dilation=1, w_init_gain='relu'),
200
+ GroupNorm_Mask(self.dim_enc//self.chs_grp, self.dim_enc))
201
+ convolutions.append(conv_layer)
202
+
203
+ conv_layer = M43_Sequential(
204
+ ConvNorm(self.dim_enc,
205
+ 128,
206
+ kernel_size=5, stride=1,
207
+ padding=2,
208
+ dilation=1, w_init_gain='relu'),
209
+ GroupNorm_Mask(128//self.chs_grp, 128))
210
+ convolutions.append(conv_layer)
211
+
212
+ conv_layer = M43_Sequential(
213
+ ConvNorm(128,
214
+ 32,
215
+ kernel_size=5, stride=1,
216
+ padding=2,
217
+ dilation=1, w_init_gain='relu'),
218
+ GroupNorm_Mask(32//self.chs_grp, 32))
219
+ convolutions.append(conv_layer)
220
+
221
+ conv_layer = M43_Sequential(
222
+ ConvNorm(32,
223
+ self.dim_neck,
224
+ kernel_size=5, stride=1,
225
+ padding=2,
226
+ dilation=1, w_init_gain='linear'),
227
+ GroupNorm_Mask(1, self.dim_neck))
228
+ convolutions.append(conv_layer)
229
+
230
+ self.convolutions = nn.ModuleList(convolutions)
231
+
232
+
233
+ def forward(self, x, mask):
234
+
235
+ for i in range(len(self.convolutions)-1):
236
+ x = F.relu(self.convolutions[i](x, mask))
237
+
238
+ x = self.convolutions[-1](x, mask)
239
+
240
+ codes = x.permute(0, 2, 1) * mask.unsqueeze(-1)
241
+
242
+ return codes
onmt_modules/__init__.py ADDED
File without changes
onmt_modules/average_attn.py ADDED
@@ -0,0 +1,111 @@
1
+ # -*- coding: utf-8 -*-
2
+ """Average Attention module."""
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from .position_ffn import PositionwiseFeedForward
8
+
9
+
10
+ class AverageAttention(nn.Module):
11
+ """
12
+ Average Attention module from
13
+ "Accelerating Neural Transformer via an Average Attention Network"
14
+ :cite:`DBLP:journals/corr/abs-1805-00631`.
15
+
16
+ Args:
17
+ model_dim (int): the dimension of keys/values/queries,
18
+ must be divisible by head_count
19
+ dropout (float): dropout parameter
20
+ """
21
+
22
+ def __init__(self, model_dim, dropout=0.1, aan_useffn=False):
23
+ self.model_dim = model_dim
24
+ self.aan_useffn = aan_useffn
25
+ super(AverageAttention, self).__init__()
26
+ if aan_useffn:
27
+ self.average_layer = PositionwiseFeedForward(model_dim, model_dim,
28
+ dropout)
29
+ self.gating_layer = nn.Linear(model_dim * 2, model_dim * 2)
30
+
31
+ def cumulative_average_mask(self, batch_size, inputs_len, device):
32
+ """
33
+ Builds the mask to compute the cumulative average as described in
34
+ :cite:`DBLP:journals/corr/abs-1805-00631` -- Figure 3
35
+
36
+ Args:
37
+ batch_size (int): batch size
38
+ inputs_len (int): length of the inputs
39
+
40
+ Returns:
41
+ (FloatTensor):
42
+
43
+ * A Tensor of shape ``(batch_size, input_len, input_len)``
44
+ """
45
+
46
+ triangle = torch.tril(torch.ones(inputs_len, inputs_len,
47
+ dtype=torch.float, device=device))
48
+ weights = torch.ones(1, inputs_len, dtype=torch.float, device=device) \
49
+ / torch.arange(1, inputs_len + 1, dtype=torch.float, device=device)
50
+ mask = triangle * weights.transpose(0, 1)
51
+
52
+ return mask.unsqueeze(0).expand(batch_size, inputs_len, inputs_len)
53
+
54
+ def cumulative_average(self, inputs, mask_or_step,
55
+ layer_cache=None, step=None):
56
+ """
57
+ Computes the cumulative average as described in
58
+ :cite:`DBLP:journals/corr/abs-1805-00631` -- Equations (1) (5) (6)
59
+
60
+ Args:
61
+ inputs (FloatTensor): sequence to average
62
+ ``(batch_size, input_len, dimension)``
63
+ mask_or_step: if cache is set, this is assumed
64
+ to be the current step of the
65
+ dynamic decoding. Otherwise, it is the mask matrix
66
+ used to compute the cumulative average.
67
+ layer_cache: a dictionary containing the cumulative average
68
+ of the previous step.
69
+
70
+ Returns:
71
+ a tensor of the same shape and type as ``inputs``.
72
+ """
73
+
74
+ if layer_cache is not None:
75
+ step = mask_or_step
76
+ average_attention = (inputs + step *
77
+ layer_cache["prev_g"]) / (step + 1)
78
+ layer_cache["prev_g"] = average_attention
79
+ return average_attention
80
+ else:
81
+ mask = mask_or_step
82
+ return torch.matmul(mask.to(inputs.dtype), inputs)
83
+
84
+ def forward(self, inputs, mask=None, layer_cache=None, step=None):
85
+ """
86
+ Args:
87
+ inputs (FloatTensor): ``(batch_size, input_len, model_dim)``
88
+
89
+ Returns:
90
+ (FloatTensor, FloatTensor):
91
+
92
+ * gating_outputs ``(batch_size, input_len, model_dim)``
93
+ * average_outputs average attention
94
+ ``(batch_size, input_len, model_dim)``
95
+ """
96
+
97
+ batch_size = inputs.size(0)
98
+ inputs_len = inputs.size(1)
99
+ average_outputs = self.cumulative_average(
100
+ inputs, self.cumulative_average_mask(batch_size,
101
+ inputs_len, inputs.device)
102
+ if layer_cache is None else step, layer_cache=layer_cache)
103
+ if self.aan_useffn:
104
+ average_outputs = self.average_layer(average_outputs)
105
+ gating_outputs = self.gating_layer(torch.cat((inputs,
106
+ average_outputs), -1))
107
+ input_gate, forget_gate = torch.chunk(gating_outputs, 2, dim=2)
108
+ gating_outputs = torch.sigmoid(input_gate) * inputs + \
109
+ torch.sigmoid(forget_gate) * average_outputs
110
+
111
+ return gating_outputs, average_outputs
onmt_modules/decoder.py ADDED
@@ -0,0 +1,25 @@
+ import torch
+ import torch.nn as nn
+
+ from .misc import aeq
+
+
+ class DecoderBase(nn.Module):
+     """Abstract class for decoders.
+
+     Args:
+         attentional (bool): The decoder returns non-empty attention.
+     """
+
+     def __init__(self, attentional=True):
+         super(DecoderBase, self).__init__()
+         self.attentional = attentional
+
+     @classmethod
+     def from_opt(cls, opt, embeddings):
+         """Alternate constructor.
+
+         Subclasses should override this method.
+         """
+
+         raise NotImplementedError
@@ -0,0 +1,358 @@
 
1
+ """
2
+ Implementation of "Attention is All You Need"
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .decoder import DecoderBase
9
+ from .multi_headed_attn import MultiHeadedAttention
10
+ from .average_attn import AverageAttention
11
+ from .position_ffn import PositionwiseFeedForward
12
+ from .misc import sequence_mask
13
+
14
+
15
+ class TransformerDecoderLayer(nn.Module):
16
+ """Transformer Decoder layer block in Pre-Norm style.
17
+ Pre-Norm style is an improvement over the original paper's Post-Norm style,
18
+ providing better convergence speed and performance. This is also the actual
19
+ implementation in tensor2tensor and is also available in fairseq.
20
+ See https://tunz.kr/post/4 and :cite:`DeeperTransformer`.
21
+
22
+ .. mermaid::
23
+
24
+ graph LR
25
+ %% "*SubLayer" can be self-attn, src-attn or feed forward block
26
+ A(input) --> B[Norm]
27
+ B --> C["*SubLayer"]
28
+ C --> D[Drop]
29
+ D --> E((+))
30
+ A --> E
31
+ E --> F(out)
32
+
33
+
34
+ Args:
35
+ d_model (int): the dimension of keys/values/queries in
36
+ :class:`MultiHeadedAttention`, also the input size of
37
+ the first-layer of the :class:`PositionwiseFeedForward`.
38
+ heads (int): the number of heads for MultiHeadedAttention.
39
+ d_ff (int): the second-layer of the :class:`PositionwiseFeedForward`.
40
+ dropout (float): dropout in residual, self-attn(dot) and feed-forward
41
+ attention_dropout (float): dropout in context_attn (and self-attn(avg))
42
+ self_attn_type (string): type of self-attention scaled-dot, average
43
+ max_relative_positions (int):
44
+ Max distance between inputs in relative positions representations
45
+ aan_useffn (bool): Turn on the FFN layer in the AAN decoder
46
+ full_context_alignment (bool):
47
+ whether enable an extra full context decoder forward for alignment
48
+ alignment_heads (int):
49
+ N. of cross attention heads to use for alignment guiding
50
+ """
51
+
52
+ def __init__(self, d_model, heads, d_ff, dropout, attention_dropout,
53
+ self_attn_type="scaled-dot", max_relative_positions=0,
54
+ aan_useffn=False, full_context_alignment=False,
55
+ alignment_heads=0):
56
+ super(TransformerDecoderLayer, self).__init__()
57
+
58
+ if self_attn_type == "scaled-dot":
59
+ self.self_attn = MultiHeadedAttention(
60
+ heads, d_model, dropout=attention_dropout,
61
+ max_relative_positions=max_relative_positions)
62
+ elif self_attn_type == "average":
63
+ self.self_attn = AverageAttention(d_model,
64
+ dropout=attention_dropout,
65
+ aan_useffn=aan_useffn)
66
+
67
+ self.context_attn = MultiHeadedAttention(
68
+ heads, d_model, dropout=attention_dropout)
69
+ self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
70
+ self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
71
+ self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
72
+ self.drop = nn.Dropout(dropout)
73
+ self.full_context_alignment = full_context_alignment
74
+ self.alignment_heads = alignment_heads
75
+
76
+ def forward(self, *args, **kwargs):
77
+ """ Extend `_forward` for (possibly) multiple decoder pass:
78
+ Always a default (future masked) decoder forward pass,
79
+ Possibly a second future aware decoder pass for joint learn
80
+ full context alignement, :cite:`garg2019jointly`.
81
+
82
+ Args:
83
+ * All arguments of _forward.
84
+ with_align (bool): whether return alignment attention.
85
+
86
+ Returns:
87
+ (FloatTensor, FloatTensor, FloatTensor or None):
88
+
89
+ * output ``(batch_size, T, model_dim)``
90
+ * top_attn ``(batch_size, T, src_len)``
91
+ * attn_align ``(batch_size, T, src_len)`` or None
92
+ """
93
+ with_align = kwargs.pop('with_align', False)
94
+ output, attns = self._forward(*args, **kwargs)
95
+ top_attn = attns[:, 0, :, :].contiguous()
96
+ attn_align = None
97
+ if with_align:
98
+ if self.full_context_alignment:
99
+ # return _, (B, Q_len, K_len)
100
+ _, attns = self._forward(*args, **kwargs, future=True)
101
+
102
+ if self.alignment_heads > 0:
103
+ attns = attns[:, :self.alignment_heads, :, :].contiguous()
104
+ # layer average attention across heads, get ``(B, Q, K)``
105
+ # Case 1: no full_context, no align heads -> layer avg baseline
106
+ # Case 2: no full_context, 1 align heads -> guided align
107
+ # Case 3: full_context, 1 align heads -> full cte guided align
108
+ attn_align = attns.mean(dim=1)
109
+ return output, top_attn, attn_align
110
+
111
+ def _forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask,
112
+ layer_cache=None, step=None, future=False):
113
+ """ A naive forward pass for transformer decoder.
114
+
115
+ # T: could be 1 in the case of stepwise decoding or tgt_len
116
+
117
+ Args:
118
+ inputs (FloatTensor): ``(batch_size, T, model_dim)``
119
+ memory_bank (FloatTensor): ``(batch_size, src_len, model_dim)``
120
+ src_pad_mask (LongTensor): ``(batch_size, 1, src_len)``
121
+ tgt_pad_mask (LongTensor): ``(batch_size, 1, T)``
122
+ layer_cache (dict or None): cached layer info when stepwise decode
123
+ step (int or None): stepwise decoding counter
124
+ future (bool): If set True, do not apply future_mask.
125
+
126
+ Returns:
127
+ (FloatTensor, FloatTensor):
128
+
129
+ * output ``(batch_size, T, model_dim)``
130
+ * attns ``(batch_size, head, T, src_len)``
131
+
132
+ """
133
+ dec_mask = None
134
+
135
+ if step is None:
136
+ tgt_len = tgt_pad_mask.size(-1)
137
+ if not future: # apply future_mask, result mask in (B, T, T)
138
+ future_mask = torch.ones(
139
+ [tgt_len, tgt_len],
140
+ device=tgt_pad_mask.device,
141
+ dtype=torch.uint8)
142
+ future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
143
+ # BoolTensor was introduced in pytorch 1.2
144
+ try:
145
+ future_mask = future_mask.bool()
146
+ except AttributeError:
147
+ pass
148
+ dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
149
+ else: # only mask padding, result mask in (B, 1, T)
150
+ dec_mask = tgt_pad_mask
151
+
152
+ input_norm = self.layer_norm_1(inputs)
153
+
154
+ if isinstance(self.self_attn, MultiHeadedAttention):
155
+ query, _ = self.self_attn(input_norm, input_norm, input_norm,
156
+ mask=dec_mask,
157
+ layer_cache=layer_cache,
158
+ attn_type="self")
159
+ elif isinstance(self.self_attn, AverageAttention):
160
+ query, _ = self.self_attn(input_norm, mask=dec_mask,
161
+ layer_cache=layer_cache, step=step)
162
+
163
+ query = self.drop(query) + inputs
164
+
165
+ query_norm = self.layer_norm_2(query)
166
+ mid, attns = self.context_attn(memory_bank, memory_bank, query_norm,
167
+ mask=src_pad_mask,
168
+ layer_cache=layer_cache,
169
+ attn_type="context")
170
+ output = self.feed_forward(self.drop(mid) + query)
171
+
172
+ return output, attns
173
+
174
+ def update_dropout(self, dropout, attention_dropout):
175
+ self.self_attn.update_dropout(attention_dropout)
176
+ self.context_attn.update_dropout(attention_dropout)
177
+ self.feed_forward.update_dropout(dropout)
178
+ self.drop.p = dropout
179
+
180
+
181
+ class TransformerDecoder(DecoderBase):
182
+ """The Transformer decoder from "Attention is All You Need".
183
+ :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`
184
+
185
+ .. mermaid::
186
+
187
+ graph BT
188
+ A[input]
189
+ B[multi-head self-attn]
190
+ BB[multi-head src-attn]
191
+ C[feed forward]
192
+ O[output]
193
+ A --> B
194
+ B --> BB
195
+ BB --> C
196
+ C --> O
197
+
198
+
199
+ Args:
200
+ num_layers (int): number of encoder layers.
201
+ d_model (int): size of the model
202
+ heads (int): number of heads
203
+ d_ff (int): size of the inner FF layer
204
+ copy_attn (bool): if using a separate copy attention
205
+ self_attn_type (str): type of self-attention scaled-dot, average
206
+ dropout (float): dropout in residual, self-attn(dot) and feed-forward
207
+ attention_dropout (float): dropout in context_attn (and self-attn(avg))
208
+ embeddings (onmt.modules.Embeddings):
209
+ embeddings to use, should have positional encodings
210
+ max_relative_positions (int):
211
+ Max distance between inputs in relative positions representations
212
+ aan_useffn (bool): Turn on the FFN layer in the AAN decoder
213
+ full_context_alignment (bool):
214
+ whether enable an extra full context decoder forward for alignment
215
+ alignment_layer (int): N° Layer to supervise with for alignment guiding
216
+ alignment_heads (int):
217
+ N. of cross attention heads to use for alignment guiding
218
+ """
219
+
220
+ def __init__(self, num_layers, d_model, heads, d_ff,
221
+ copy_attn, self_attn_type, dropout, attention_dropout,
222
+ embeddings, max_relative_positions, aan_useffn,
223
+ full_context_alignment, alignment_layer,
224
+ alignment_heads):
225
+ super(TransformerDecoder, self).__init__()
226
+
227
+ self.embeddings = embeddings
228
+
229
+ # Decoder State
230
+ self.state = {}
231
+
232
+ self.transformer_layers = nn.ModuleList(
233
+ [TransformerDecoderLayer(d_model, heads, d_ff, dropout,
234
+ attention_dropout, self_attn_type=self_attn_type,
235
+ max_relative_positions=max_relative_positions,
236
+ aan_useffn=aan_useffn,
237
+ full_context_alignment=full_context_alignment,
238
+ alignment_heads=alignment_heads)
239
+ for i in range(num_layers)])
240
+
241
+ # previously, there was a GlobalAttention module here for copy
242
+ # attention. But it was never actually used -- the "copy" attention
243
+ # just reuses the context attention.
244
+ self._copy = copy_attn
245
+ self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
246
+
247
+ self.alignment_layer = alignment_layer
248
+
249
+ @classmethod
250
+ def from_opt(cls, opt, embeddings):
251
+ """Alternate constructor."""
252
+ return cls(
253
+ opt.dec_layers,
254
+ opt.dec_rnn_size,
255
+ opt.heads,
256
+ opt.transformer_ff,
257
+ opt.copy_attn,
258
+ opt.self_attn_type,
259
+ opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
260
+ opt.attention_dropout[0] if type(opt.attention_dropout)
261
+ is list else opt.dropout,
262
+ embeddings,
263
+ opt.max_relative_positions,
264
+ opt.aan_useffn,
265
+ opt.full_context_alignment,
266
+ opt.alignment_layer,
267
+ alignment_heads=opt.alignment_heads)
268
+
269
+ def init_state(self, src, memory_bank, enc_hidden):
270
+ """Initialize decoder state."""
271
+ self.state["src"] = src
272
+ self.state["cache"] = None
273
+
274
+ def map_state(self, fn):
275
+ def _recursive_map(struct, batch_dim=0):
276
+ for k, v in struct.items():
277
+ if v is not None:
278
+ if isinstance(v, dict):
279
+ _recursive_map(v)
280
+ else:
281
+ struct[k] = fn(v, batch_dim)
282
+
283
+ self.state["src"] = fn(self.state["src"], 1)
284
+ if self.state["cache"] is not None:
285
+ _recursive_map(self.state["cache"])
286
+
287
+ def detach_state(self):
288
+ self.state["src"] = self.state["src"].detach()
289
+
290
+ def forward(self, tgt, memory_bank, step=None, **kwargs):
291
+ """Decode, possibly stepwise."""
292
+ if step == 0:
293
+ self._init_cache(memory_bank)
294
+
295
+ tgt_words = tgt[:, :, 0].transpose(0, 1)
296
+
297
+ emb = self.embeddings(tgt, step=step)
298
+ assert emb.dim() == 3 # len x batch x embedding_dim
299
+
300
+ output = emb.transpose(0, 1).contiguous()
301
+ src_memory_bank = memory_bank.transpose(0, 1).contiguous()
302
+
303
+ pad_idx = self.embeddings.word_padding_idx
304
+ src_lens = kwargs["memory_lengths"]
305
+ src_max_len = self.state["src"].shape[0]
306
+ src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1)
307
+ tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1) # [B, 1, T_tgt]
308
+
309
+ with_align = kwargs.pop('with_align', False)
310
+ attn_aligns = []
311
+
312
+ for i, layer in enumerate(self.transformer_layers):
313
+ layer_cache = self.state["cache"]["layer_{}".format(i)] \
314
+ if step is not None else None
315
+ output, attn, attn_align = layer(
316
+ output,
317
+ src_memory_bank,
318
+ src_pad_mask,
319
+ tgt_pad_mask,
320
+ layer_cache=layer_cache,
321
+ step=step,
322
+ with_align=with_align)
323
+ if attn_align is not None:
324
+ attn_aligns.append(attn_align)
325
+
326
+ output = self.layer_norm(output)
327
+ dec_outs = output.transpose(0, 1).contiguous()
328
+ attn = attn.transpose(0, 1).contiguous()
329
+
330
+ attns = {"std": attn}
331
+ if self._copy:
332
+ attns["copy"] = attn
333
+ if with_align:
334
+ attns["align"] = attn_aligns[self.alignment_layer] # `(B, Q, K)`
335
+ # attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg
336
+
337
+ # TODO change the way attns is returned dict => list or tuple (onnx)
338
+ return dec_outs, attns
339
+
340
+ def _init_cache(self, memory_bank):
341
+ self.state["cache"] = {}
342
+ batch_size = memory_bank.size(1)
343
+ depth = memory_bank.size(-1)
344
+
345
+ for i, layer in enumerate(self.transformer_layers):
346
+ layer_cache = {"memory_keys": None, "memory_values": None}
347
+ if isinstance(layer.self_attn, AverageAttention):
348
+ layer_cache["prev_g"] = torch.zeros((batch_size, 1, depth),
349
+ device=memory_bank.device)
350
+ else:
351
+ layer_cache["self_keys"] = None
352
+ layer_cache["self_values"] = None
353
+ self.state["cache"]["layer_{}".format(i)] = layer_cache
354
+
355
+ def update_dropout(self, dropout, attention_dropout):
356
+ self.embeddings.update_dropout(dropout)
357
+ for layer in self.transformer_layers:
358
+ layer.update_dropout(dropout, attention_dropout)
onmt_modules/embeddings.py ADDED
@@ -0,0 +1,52 @@
1
+ """ Embeddings module """
2
+ import math
3
+ import warnings
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
+ class PositionalEncoding(nn.Module):
10
+ """Sinusoidal positional encoding for non-recurrent neural networks.
11
+
12
+ Implementation based on "Attention Is All You Need"
13
+ :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`
14
+
15
+ Args:
16
+ dropout (float): dropout parameter
17
+ dim (int): embedding size
18
+ """
19
+
20
+ def __init__(self, dropout, dim, max_len=5000):
21
+ if dim % 2 != 0:
22
+ raise ValueError("Cannot use sin/cos positional encoding with "
23
+ "odd dim (got dim={:d})".format(dim))
24
+ pe = torch.zeros(max_len, dim)
25
+ position = torch.arange(0, max_len).unsqueeze(1)
26
+ div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
27
+ -(math.log(10000.0) / dim)))
28
+ pe[:, 0::2] = torch.sin(position.float() * div_term)
29
+ pe[:, 1::2] = torch.cos(position.float() * div_term)
30
+ pe = pe.unsqueeze(1)
31
+ super(PositionalEncoding, self).__init__()
32
+ self.register_buffer('pe', pe)
33
+ self.dropout = nn.Dropout(p=dropout)
34
+ self.dim = dim
35
+
36
+ def forward(self, emb, step=None):
37
+ """Embed inputs.
38
+
39
+ Args:
40
+ emb (FloatTensor): Sequence of word vectors
41
+ ``(seq_len, batch_size, self.dim)``
42
+ step (int or NoneType): If stepwise (``seq_len = 1``), use
43
+ the encoding for this position.
44
+ """
45
+
46
+ emb = emb * math.sqrt(self.dim)
47
+ if step is None:
48
+ emb = emb + self.pe[:emb.size(0)]
49
+ else:
50
+ emb = emb + self.pe[step]
51
+ emb = self.dropout(emb)
52
+ return emb
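A minimal usage sketch for the module above (not part of the committed file), assuming the repo root is on `PYTHONPATH` and PyTorch is installed; tensor sizes are illustrative only:

```python
import torch
from onmt_modules.embeddings import PositionalEncoding

pe = PositionalEncoding(dropout=0.1, dim=256)

# Full-sequence mode: encodings for positions 0..seq_len-1 are added.
emb = torch.randn(20, 4, 256)          # (seq_len, batch, dim)
print(pe(emb).shape)                   # torch.Size([20, 4, 256])

# Stepwise mode (incremental decoding): only the encoding at `step` is added.
step_emb = torch.randn(1, 4, 256)
print(pe(step_emb, step=7).shape)      # torch.Size([1, 4, 256])
```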
onmt_modules/encoder.py ADDED
@@ -0,0 +1,58 @@
1
+ """Base class for encoders and generic multi encoders."""
2
+
3
+ import torch.nn as nn
4
+
5
+ from .misc import aeq
6
+
7
+
8
+ class EncoderBase(nn.Module):
9
+ """
10
+ Base encoder class. Specifies the interface used by different encoder types
11
+ and required by :class:`onmt.Models.NMTModel`.
12
+
13
+ .. mermaid::
14
+
15
+ graph BT
16
+ A[Input]
17
+ subgraph RNN
18
+ C[Pos 1]
19
+ D[Pos 2]
20
+ E[Pos N]
21
+ end
22
+ F[Memory_Bank]
23
+ G[Final]
24
+ A-->C
25
+ A-->D
26
+ A-->E
27
+ C-->F
28
+ D-->F
29
+ E-->F
30
+ E-->G
31
+ """
32
+
33
+ @classmethod
34
+ def from_opt(cls, opt, embeddings=None):
35
+ raise NotImplementedError
36
+
37
+ def _check_args(self, src, lengths=None, hidden=None):
38
+ n_batch = src.size(1)
39
+ if lengths is not None:
40
+ n_batch_, = lengths.size()
41
+ aeq(n_batch, n_batch_)
42
+
43
+ def forward(self, src, lengths=None):
44
+ """
45
+ Args:
46
+ src (LongTensor):
47
+ padded sequences of sparse indices ``(src_len, batch, nfeat)``
48
+ lengths (LongTensor): length of each sequence ``(batch,)``
49
+
50
+
51
+ Returns:
52
+ (FloatTensor, FloatTensor):
53
+
54
+ * final encoder state, used to initialize decoder
55
+ * memory bank for attention, ``(src_len, batch, hidden)``
56
+ """
57
+
58
+ raise NotImplementedError
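For illustration only (not part of the committed file): a toy subclass showing the contract EncoderBase expects. The class name and option fields are hypothetical; the concrete encoder used in this repo is the TransformerEncoder in the next file.

```python
import torch
import torch.nn as nn
from onmt_modules.encoder import EncoderBase


class MeanPoolEncoder(EncoderBase):
    """Toy encoder: embeds tokens and mean-pools them into a 'final state'."""

    def __init__(self, vocab_size, d_model):
        super(MeanPoolEncoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)

    @classmethod
    def from_opt(cls, opt, embeddings=None):
        # Hypothetical option names, mirroring TransformerEncoder.from_opt.
        return cls(opt.src_vocab_size, opt.enc_rnn_size)

    def forward(self, src, lengths=None):
        self._check_args(src, lengths)         # src: (src_len, batch, nfeat)
        emb = self.embedding(src[:, :, 0])     # (src_len, batch, d_model)
        final_state = emb.mean(dim=0)          # (batch, d_model)
        return final_state, emb                # final state, memory bank


src = torch.randint(0, 100, (12, 3, 1))
lengths = torch.tensor([12, 9, 7])
state, memory_bank = MeanPoolEncoder(100, 64)(src, lengths)
print(state.shape, memory_bank.shape)          # (3, 64) (12, 3, 64)
```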
onmt_modules/encoder_transformer.py ADDED
@@ -0,0 +1,135 @@
1
+ """
2
+ Implementation of "Attention is All You Need"
3
+ """
4
+
5
+ import torch.nn as nn
6
+
7
+ from .encoder import EncoderBase
8
+ from .multi_headed_attn import MultiHeadedAttention
9
+ from .position_ffn import PositionwiseFeedForward
10
+ from .misc import sequence_mask
11
+
12
+
13
+ class TransformerEncoderLayer(nn.Module):
14
+ """
15
+ A single layer of the transformer encoder.
16
+
17
+ Args:
18
+ d_model (int): the dimension of keys/values/queries in
19
+ MultiHeadedAttention, also the input size of
20
+ the first-layer of the PositionwiseFeedForward.
21
+ heads (int): the number of heads for MultiHeadedAttention.
22
+ d_ff (int): the hidden dimension of the PositionwiseFeedForward.
23
+ dropout (float): dropout probability (0-1.0).
24
+ """
25
+
26
+ def __init__(self, d_model, heads, d_ff, dropout, attention_dropout,
27
+ max_relative_positions=0):
28
+ super(TransformerEncoderLayer, self).__init__()
29
+
30
+ self.self_attn = MultiHeadedAttention(
31
+ heads, d_model, dropout=attention_dropout,
32
+ max_relative_positions=max_relative_positions)
33
+ self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
34
+ self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
35
+ self.dropout = nn.Dropout(dropout)
36
+
37
+ def forward(self, inputs, mask):
38
+ """
39
+ Args:
40
+ inputs (FloatTensor): ``(batch_size, src_len, model_dim)``
41
+ mask (LongTensor): ``(batch_size, 1, src_len)``
42
+
43
+ Returns:
44
+ (FloatTensor):
45
+
46
+ * outputs ``(batch_size, src_len, model_dim)``
47
+ """
48
+ input_norm = self.layer_norm(inputs)
49
+ context, _ = self.self_attn(input_norm, input_norm, input_norm,
50
+ mask=mask, attn_type="self")
51
+ out = self.dropout(context) + inputs
52
+ return self.feed_forward(out)
53
+
54
+ def update_dropout(self, dropout, attention_dropout):
55
+ self.self_attn.update_dropout(attention_dropout)
56
+ self.feed_forward.update_dropout(dropout)
57
+ self.dropout.p = dropout
58
+
59
+
60
+ class TransformerEncoder(EncoderBase):
61
+ """The Transformer encoder from "Attention is All You Need"
62
+ :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`
63
+
64
+ .. mermaid::
65
+
66
+ graph BT
67
+ A[input]
68
+ B[multi-head self-attn]
69
+ C[feed forward]
70
+ O[output]
71
+ A --> B
72
+ B --> C
73
+ C --> O
74
+
75
+ Args:
76
+ num_layers (int): number of encoder layers
77
+ d_model (int): size of the model
78
+ heads (int): number of heads
79
+ d_ff (int): size of the inner FF layer
80
+ dropout (float): dropout parameters
81
+ embeddings (onmt.modules.Embeddings):
82
+ embeddings to use, should have positional encodings
83
+
84
+ Returns:
85
+ (torch.FloatTensor, torch.FloatTensor, torch.LongTensor):
86
+
87
+ * embeddings ``(src_len, batch_size, model_dim)``
88
+ * memory_bank ``(src_len, batch_size, model_dim)``
+ * lengths ``(batch,)``
89
+ """
90
+
91
+ def __init__(self, num_layers, d_model, heads, d_ff, dropout,
92
+ attention_dropout, embeddings, max_relative_positions):
93
+ super(TransformerEncoder, self).__init__()
94
+
95
+ self.embeddings = embeddings
96
+ self.transformer = nn.ModuleList(
97
+ [TransformerEncoderLayer(
98
+ d_model, heads, d_ff, dropout, attention_dropout,
99
+ max_relative_positions=max_relative_positions)
100
+ for i in range(num_layers)])
101
+ self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
102
+
103
+ @classmethod
104
+ def from_opt(cls, opt, embeddings):
105
+ """Alternate constructor."""
106
+ return cls(
107
+ opt.enc_layers,
108
+ opt.enc_rnn_size,
109
+ opt.heads,
110
+ opt.transformer_ff,
111
+ opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
112
+ opt.attention_dropout[0] if type(opt.attention_dropout)
113
+ is list else opt.attention_dropout,
114
+ embeddings,
115
+ opt.max_relative_positions)
116
+
117
+ def forward(self, src, lengths=None):
118
+ """See :func:`EncoderBase.forward()`"""
119
+ self._check_args(src, lengths)
120
+
121
+ emb = self.embeddings(src)
122
+
123
+ out = emb.transpose(0, 1).contiguous()
124
+ mask = ~sequence_mask(lengths).unsqueeze(1)
125
+ # Run the forward pass of every layer of the transformer.
126
+ for layer in self.transformer:
127
+ out = layer(out, mask)
128
+ out = self.layer_norm(out)
129
+
130
+ return emb, out.transpose(0, 1).contiguous(), lengths
131
+
132
+ def update_dropout(self, dropout, attention_dropout):
133
+ self.embeddings.update_dropout(dropout)
134
+ for layer in self.transformer:
135
+ layer.update_dropout(dropout, attention_dropout)
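A minimal sketch of one encoder forward pass (not part of the committed file). The small `TokenEmbedding` wrapper below is a hypothetical stand-in for the full OpenNMT embeddings module; the encoder only requires something that maps `(src_len, batch, nfeat)` indices to `(src_len, batch, d_model)` vectors.

```python
import torch
import torch.nn as nn
from onmt_modules.embeddings import PositionalEncoding
from onmt_modules.encoder_transformer import TransformerEncoder


class TokenEmbedding(nn.Module):
    """Hypothetical embedding wrapper: lookup table + positional encoding."""
    def __init__(self, vocab_size, d_model, dropout=0.1):
        super(TokenEmbedding, self).__init__()
        self.lut = nn.Embedding(vocab_size, d_model)
        self.pe = PositionalEncoding(dropout, d_model)

    def forward(self, src, step=None):
        return self.pe(self.lut(src[:, :, 0]), step=step)


encoder = TransformerEncoder(num_layers=2, d_model=64, heads=4, d_ff=256,
                             dropout=0.1, attention_dropout=0.1,
                             embeddings=TokenEmbedding(100, 64),
                             max_relative_positions=0)

src = torch.randint(0, 100, (15, 3, 1))        # (src_len, batch, nfeat)
lengths = torch.tensor([15, 11, 8])
emb, memory_bank, _ = encoder(src, lengths)
print(emb.shape, memory_bank.shape)            # both (15, 3, 64)
```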
onmt_modules/misc.py ADDED
@@ -0,0 +1,173 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ import random
5
+ import inspect
6
+ from itertools import islice, repeat
7
+ import os
8
+
9
+
10
+ def split_corpus(path, shard_size, default=None):
11
+ """yield a `list` containing `shard_size` line of `path`,
12
+ or repeatly generate `default` if `path` is None.
13
+ """
14
+ if path is not None:
15
+ return _split_corpus(path, shard_size)
16
+ else:
17
+ return repeat(default)
18
+
19
+
20
+ def _split_corpus(path, shard_size):
21
+ """Yield a `list` containing `shard_size` line of `path`.
22
+ """
23
+ with open(path, "rb") as f:
24
+ if shard_size <= 0:
25
+ yield f.readlines()
26
+ else:
27
+ while True:
28
+ shard = list(islice(f, shard_size))
29
+ if not shard:
30
+ break
31
+ yield shard
32
+
33
+
34
+ def aeq(*args):
35
+ """
36
+ Assert all arguments have the same value
37
+ """
38
+ arguments = (arg for arg in args)
39
+ first = next(arguments)
40
+ assert all(arg == first for arg in arguments), \
41
+ "Not all arguments have the same value: " + str(args)
42
+
43
+
44
+ def sequence_mask(lengths, max_len=None):
45
+ """
46
+ Creates a boolean mask from sequence lengths.
47
+ """
48
+ batch_size = lengths.numel()
49
+ max_len = max_len or lengths.max()
50
+ return (torch.arange(0, max_len, device=lengths.device)
51
+ .type_as(lengths)
52
+ .repeat(batch_size, 1)
53
+ .lt(lengths.unsqueeze(1)))
54
+
55
+
56
+ def tile(x, count, dim=0):
57
+ """
58
+ Tiles x on dimension dim count times.
59
+ """
60
+ perm = list(range(len(x.size())))
61
+ if dim != 0:
62
+ perm[0], perm[dim] = perm[dim], perm[0]
63
+ x = x.permute(perm).contiguous()
64
+ out_size = list(x.size())
65
+ out_size[0] *= count
66
+ batch = x.size(0)
67
+ x = x.view(batch, -1) \
68
+ .transpose(0, 1) \
69
+ .repeat(count, 1) \
70
+ .transpose(0, 1) \
71
+ .contiguous() \
72
+ .view(*out_size)
73
+ if dim != 0:
74
+ x = x.permute(perm).contiguous()
75
+ return x
76
+
77
+
78
+ def use_gpu(opt):
79
+ """
80
+ Returns True if a GPU should be used, based on the parsed options.
81
+ """
82
+ return (hasattr(opt, 'gpu_ranks') and len(opt.gpu_ranks) > 0) or \
83
+ (hasattr(opt, 'gpu') and opt.gpu > -1)
84
+
85
+
86
+ def set_random_seed(seed, is_cuda):
87
+ """Sets the random seed."""
88
+ if seed > 0:
89
+ torch.manual_seed(seed)
90
+ # this one is needed for torchtext random call (shuffled iterator)
91
+ # in multi gpu it ensures datasets are read in the same order
92
+ random.seed(seed)
93
+ # some cudnn methods can be random even after fixing the seed
94
+ # unless you tell it to be deterministic
95
+ torch.backends.cudnn.deterministic = True
96
+
97
+ if is_cuda and seed > 0:
98
+ # These ensure same initialization in multi gpu mode
99
+ torch.cuda.manual_seed(seed)
100
+
101
+
102
+ def generate_relative_positions_matrix(length, max_relative_positions,
103
+ cache=False):
104
+ """Generate the clipped relative positions matrix
105
+ for a given length and maximum relative positions"""
106
+ if cache:
107
+ distance_mat = torch.arange(-length+1, 1, 1).unsqueeze(0)
108
+ else:
109
+ range_vec = torch.arange(length)
110
+ range_mat = range_vec.unsqueeze(-1).expand(-1, length).transpose(0, 1)
111
+ distance_mat = range_mat - range_mat.transpose(0, 1)
112
+ distance_mat_clipped = torch.clamp(distance_mat,
113
+ min=-max_relative_positions,
114
+ max=max_relative_positions)
115
+ # Shift values to be >= 0
116
+ final_mat = distance_mat_clipped + max_relative_positions
117
+ return final_mat
118
+
119
+
120
+ def relative_matmul(x, z, transpose):
121
+ """Helper function for relative positions attention."""
122
+ batch_size = x.shape[0]
123
+ heads = x.shape[1]
124
+ length = x.shape[2]
125
+ x_t = x.permute(2, 0, 1, 3)
126
+ x_t_r = x_t.reshape(length, heads * batch_size, -1)
127
+ if transpose:
128
+ z_t = z.transpose(1, 2)
129
+ x_tz_matmul = torch.matmul(x_t_r, z_t)
130
+ else:
131
+ x_tz_matmul = torch.matmul(x_t_r, z)
132
+ x_tz_matmul_r = x_tz_matmul.reshape(length, batch_size, heads, -1)
133
+ x_tz_matmul_r_t = x_tz_matmul_r.permute(1, 2, 0, 3)
134
+ return x_tz_matmul_r_t
135
+
136
+
137
+ def fn_args(fun):
138
+ """Returns the list of function arguments name."""
139
+ return inspect.getfullargspec(fun).args
140
+
141
+
142
+ def report_matrix(row_label, column_label, matrix):
143
+ header_format = "{:>10.10} " + "{:>10.7} " * len(row_label)
144
+ row_format = "{:>10.10} " + "{:>10.7f} " * len(row_label)
145
+ output = header_format.format("", *row_label) + '\n'
146
+ for word, row in zip(column_label, matrix):
147
+ max_index = row.index(max(row))
148
+ row_format = row_format.replace(
149
+ "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
150
+ row_format = row_format.replace(
151
+ "{:*>10.7f} ", "{:>10.7f} ", max_index)
152
+ output += row_format.format(word, *row) + '\n'
153
+ row_format = "{:>10.10} " + "{:>10.7f} " * len(row_label)
154
+ return output
155
+
156
+
157
+ def check_model_config(model_config, root):
158
+ # we need to check the model path + any tokenizer path
159
+ for model in model_config["models"]:
160
+ model_path = os.path.join(root, model)
161
+ if not os.path.exists(model_path):
162
+ raise FileNotFoundError(
163
+ "{} from model {} does not exist".format(
164
+ model_path, model_config["id"]))
165
+ if "tokenizer" in model_config.keys():
166
+ if "params" in model_config["tokenizer"].keys():
167
+ for k, v in model_config["tokenizer"]["params"].items():
168
+ if k.endswith("path"):
169
+ tok_path = os.path.join(root, v)
170
+ if not os.path.exists(tok_path):
171
+ raise FileNotFoundError(
172
+ "{} from model {} does not exist".format(
173
+ tok_path, model_config["id"]))
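A short sketch (not part of the committed file) of the two helpers the rest of the code leans on most: `sequence_mask`, which turns lengths into a boolean mask, and `tile`, which expands a tensor along one dimension (e.g. for beam search).

```python
import torch
from onmt_modules.misc import sequence_mask, tile

lengths = torch.tensor([4, 2, 3])
mask = sequence_mask(lengths, max_len=5)
# tensor([[ True,  True,  True,  True, False],
#         [ True,  True, False, False, False],
#         [ True,  True,  True, False, False]])

# The encoder/decoder negate it to flag padding positions instead:
pad_mask = ~mask.unsqueeze(1)                  # (batch, 1, max_len)

x = torch.arange(6).view(3, 2)
print(tile(x, 2, dim=0))
# tensor([[0, 1], [0, 1], [2, 3], [2, 3], [4, 5], [4, 5]]) -- each row twice
```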
onmt_modules/multi_headed_attn.py ADDED
@@ -0,0 +1,230 @@
1
+ """ Multi-Head Attention module """
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .misc import generate_relative_positions_matrix,\
7
+ relative_matmul
8
+ # from onmt.utils.misc import aeq
9
+
10
+
11
+ class MultiHeadedAttention(nn.Module):
12
+ """Multi-Head Attention module from "Attention is All You Need"
13
+ :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`.
14
+
15
+ Similar to standard `dot` attention but uses
16
+ multiple attention distributions simultaneously
17
+ to select relevant items.
18
+
19
+ .. mermaid::
20
+
21
+ graph BT
22
+ A[key]
23
+ B[value]
24
+ C[query]
25
+ O[output]
26
+ subgraph Attn
27
+ D[Attn 1]
28
+ E[Attn 2]
29
+ F[Attn N]
30
+ end
31
+ A --> D
32
+ C --> D
33
+ A --> E
34
+ C --> E
35
+ A --> F
36
+ C --> F
37
+ D --> O
38
+ E --> O
39
+ F --> O
40
+ B --> O
41
+
42
+ Also includes several additional tricks.
43
+
44
+ Args:
45
+ head_count (int): number of parallel heads
46
+ model_dim (int): the dimension of keys/values/queries,
47
+ must be divisible by head_count
48
+ dropout (float): dropout parameter
49
+ """
50
+
51
+ def __init__(self, head_count, model_dim, dropout=0.1,
52
+ max_relative_positions=0):
53
+ assert model_dim % head_count == 0
54
+ self.dim_per_head = model_dim // head_count
55
+ self.model_dim = model_dim
56
+
57
+ super(MultiHeadedAttention, self).__init__()
58
+ self.head_count = head_count
59
+
60
+ self.linear_keys = nn.Linear(model_dim,
61
+ head_count * self.dim_per_head)
62
+ self.linear_values = nn.Linear(model_dim,
63
+ head_count * self.dim_per_head)
64
+ self.linear_query = nn.Linear(model_dim,
65
+ head_count * self.dim_per_head)
66
+ self.softmax = nn.Softmax(dim=-1)
67
+ self.dropout = nn.Dropout(dropout)
68
+ self.final_linear = nn.Linear(model_dim, model_dim)
69
+
70
+ self.max_relative_positions = max_relative_positions
71
+
72
+ if max_relative_positions > 0:
73
+ vocab_size = max_relative_positions * 2 + 1
74
+ self.relative_positions_embeddings = nn.Embedding(
75
+ vocab_size, self.dim_per_head)
76
+
77
+ def forward(self, key, value, query, mask=None,
78
+ layer_cache=None, attn_type=None):
79
+ """
80
+ Compute the context vector and the attention vectors.
81
+
82
+ Args:
83
+ key (FloatTensor): set of `key_len`
84
+ key vectors ``(batch, key_len, dim)``
85
+ value (FloatTensor): set of `key_len`
86
+ value vectors ``(batch, key_len, dim)``
87
+ query (FloatTensor): set of `query_len`
88
+ query vectors ``(batch, query_len, dim)``
89
+ mask: binary mask 1/0 indicating which keys have
90
+ zero / non-zero attention ``(batch, query_len, key_len)``
91
+ Returns:
92
+ (FloatTensor, FloatTensor):
93
+
94
+ * output context vectors ``(batch, query_len, dim)``
95
+ * Attention vector in heads ``(batch, head, query_len, key_len)``.
96
+ """
97
+
98
+ # CHECKS
99
+ # batch, k_len, d = key.size()
100
+ # batch_, k_len_, d_ = value.size()
101
+ # aeq(batch, batch_)
102
+ # aeq(k_len, k_len_)
103
+ # aeq(d, d_)
104
+ # batch_, q_len, d_ = query.size()
105
+ # aeq(batch, batch_)
106
+ # aeq(d, d_)
107
+ # aeq(self.model_dim % 8, 0)
108
+ # if mask is not None:
109
+ # batch_, q_len_, k_len_ = mask.size()
110
+ # aeq(batch_, batch)
111
+ # aeq(k_len_, k_len)
112
+ # aeq(q_len_ == q_len)
113
+ # END CHECKS
114
+
115
+ batch_size = key.size(0)
116
+ dim_per_head = self.dim_per_head
117
+ head_count = self.head_count
118
+ key_len = key.size(1)
119
+ query_len = query.size(1)
120
+
121
+ def shape(x):
122
+ """Projection."""
123
+ return x.view(batch_size, -1, head_count, dim_per_head) \
124
+ .transpose(1, 2)
125
+
126
+ def unshape(x):
127
+ """Compute context."""
128
+ return x.transpose(1, 2).contiguous() \
129
+ .view(batch_size, -1, head_count * dim_per_head)
130
+
131
+ # 1) Project key, value, and query.
132
+ if layer_cache is not None:
133
+ if attn_type == "self":
134
+ query, key, value = self.linear_query(query),\
135
+ self.linear_keys(query),\
136
+ self.linear_values(query)
137
+ key = shape(key)
138
+ value = shape(value)
139
+ if layer_cache["self_keys"] is not None:
140
+ key = torch.cat(
141
+ (layer_cache["self_keys"], key),
142
+ dim=2)
143
+ if layer_cache["self_values"] is not None:
144
+ value = torch.cat(
145
+ (layer_cache["self_values"], value),
146
+ dim=2)
147
+ layer_cache["self_keys"] = key
148
+ layer_cache["self_values"] = value
149
+ elif attn_type == "context":
150
+ query = self.linear_query(query)
151
+ if layer_cache["memory_keys"] is None:
152
+ key, value = self.linear_keys(key),\
153
+ self.linear_values(value)
154
+ key = shape(key)
155
+ value = shape(value)
156
+ else:
157
+ key, value = layer_cache["memory_keys"],\
158
+ layer_cache["memory_values"]
159
+ layer_cache["memory_keys"] = key
160
+ layer_cache["memory_values"] = value
161
+ else:
162
+ key = self.linear_keys(key)
163
+ value = self.linear_values(value)
164
+ query = self.linear_query(query)
165
+ key = shape(key)
166
+ value = shape(value)
167
+
168
+ if self.max_relative_positions > 0 and attn_type == "self":
169
+ key_len = key.size(2)
170
+ # 1 or key_len x key_len
171
+ relative_positions_matrix = generate_relative_positions_matrix(
172
+ key_len, self.max_relative_positions,
173
+ cache=True if layer_cache is not None else False)
174
+ # 1 or key_len x key_len x dim_per_head
175
+ relations_keys = self.relative_positions_embeddings(
176
+ relative_positions_matrix.to(key.device))
177
+ # 1 or key_len x key_len x dim_per_head
178
+ relations_values = self.relative_positions_embeddings(
179
+ relative_positions_matrix.to(key.device))
180
+
181
+ query = shape(query)
182
+
183
+ key_len = key.size(2)
184
+ query_len = query.size(2)
185
+
186
+ # 2) Calculate and scale scores.
187
+ query = query / math.sqrt(dim_per_head)
188
+ # batch x num_heads x query_len x key_len
189
+ query_key = torch.matmul(query, key.transpose(2, 3))
190
+
191
+ if self.max_relative_positions > 0 and attn_type == "self":
192
+ scores = query_key + relative_matmul(query, relations_keys, True)
193
+ else:
194
+ scores = query_key
195
+ scores = scores.float()
196
+
197
+ if mask is not None:
198
+ mask = mask.unsqueeze(1) # [B, 1, 1, T_values]
199
+ scores = scores.masked_fill(mask, -1e18)
200
+
201
+ # 3) Apply attention dropout and compute context vectors.
202
+ attn = self.softmax(scores).to(query.dtype)
203
+ drop_attn = self.dropout(attn)
204
+
205
+ context_original = torch.matmul(drop_attn, value)
206
+
207
+ if self.max_relative_positions > 0 and attn_type == "self":
208
+ context = unshape(context_original
209
+ + relative_matmul(drop_attn,
210
+ relations_values,
211
+ False))
212
+ else:
213
+ context = unshape(context_original)
214
+
215
+ output = self.final_linear(context)
216
+ # CHECK
217
+ # batch_, q_len_, d_ = output.size()
218
+ # aeq(q_len, q_len_)
219
+ # aeq(batch, batch_)
220
+ # aeq(d, d_)
221
+
222
+ # Return multi-head attn
223
+ attns = attn \
224
+ .view(batch_size, head_count,
225
+ query_len, key_len)
226
+
227
+ return output, attns
228
+
229
+ def update_dropout(self, dropout):
230
+ self.dropout.p = dropout
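A minimal self-attention call (not part of the committed file), mirroring how TransformerEncoderLayer uses this module with a padding mask:

```python
import torch
from onmt_modules.misc import sequence_mask
from onmt_modules.multi_headed_attn import MultiHeadedAttention

attn = MultiHeadedAttention(head_count=4, model_dim=64, dropout=0.1)

x = torch.randn(2, 10, 64)                       # (batch, len, model_dim)
lengths = torch.tensor([10, 6])
mask = ~sequence_mask(lengths, 10).unsqueeze(1)  # (batch, 1, len): True on pads

out, attns = attn(x, x, x, mask=mask, attn_type="self")
print(out.shape)    # torch.Size([2, 10, 64])
print(attns.shape)  # torch.Size([2, 4, 10, 10]) -- (batch, heads, query, key)
```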
onmt_modules/position_ffn.py ADDED
@@ -0,0 +1,41 @@
1
+ """Position feed-forward network from "Attention is All You Need"."""
2
+
3
+ import torch.nn as nn
4
+
5
+
6
+ class PositionwiseFeedForward(nn.Module):
7
+ """ A two-layer Feed-Forward-Network with residual layer norm.
8
+
9
+ Args:
10
+ d_model (int): the size of input for the first-layer of the FFN.
11
+ d_ff (int): the hidden layer size of the second-layer
12
+ of the FFN.
13
+ dropout (float): dropout probability in :math:`[0, 1)`.
14
+ """
15
+
16
+ def __init__(self, d_model, d_ff, dropout=0.1):
17
+ super(PositionwiseFeedForward, self).__init__()
18
+ self.w_1 = nn.Linear(d_model, d_ff)
19
+ self.w_2 = nn.Linear(d_ff, d_model)
20
+ self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
21
+ self.dropout_1 = nn.Dropout(dropout)
22
+ self.relu = nn.ReLU()
23
+ self.dropout_2 = nn.Dropout(dropout)
24
+
25
+ def forward(self, x):
26
+ """Layer definition.
27
+
28
+ Args:
29
+ x: ``(batch_size, input_len, model_dim)``
30
+
31
+ Returns:
32
+ (FloatTensor): Output ``(batch_size, input_len, model_dim)``.
33
+ """
34
+
35
+ inter = self.dropout_1(self.relu(self.w_1(self.layer_norm(x))))
36
+ output = self.dropout_2(self.w_2(inter))
37
+ return output + x
38
+
39
+ def update_dropout(self, dropout):
40
+ self.dropout_1.p = dropout
41
+ self.dropout_2.p = dropout
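For reference (not part of the committed file), the module is a pre-norm residual block: the layer norm is applied to the input, and the residual is added around the whole two-layer network.

```python
import torch
from onmt_modules.position_ffn import PositionwiseFeedForward

ffn = PositionwiseFeedForward(d_model=64, d_ff=256, dropout=0.1)
x = torch.randn(2, 10, 64)      # (batch, input_len, model_dim)
print(ffn(x).shape)             # torch.Size([2, 10, 64])
```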
override_decoder.py ADDED
@@ -0,0 +1,63 @@
1
+ from onmt_modules.decoder_transformer import TransformerDecoder
2
+ from onmt_modules.misc import sequence_mask
3
+
4
+
5
+ class OnmtDecoder_1(TransformerDecoder):
6
+ # override forward
7
+ # without teacher forcing for stop
8
+ def forward(self, tgt, memory_bank, step=None, **kwargs):
9
+ """Decode, possibly stepwise."""
10
+ if step == 0:
11
+ self._init_cache(memory_bank)
12
+
13
+ if step is None:
14
+ tgt_lens = kwargs["tgt_lengths"]
15
+ else:
16
+ tgt_words = kwargs["tgt_words"]
17
+
18
+ emb = self.embeddings(tgt, step=step)
19
+ assert emb.dim() == 3 # len x batch x embedding_dim
20
+
21
+ output = emb.transpose(0, 1).contiguous()
22
+ src_memory_bank = memory_bank.transpose(0, 1).contiguous()
23
+
24
+ pad_idx = self.embeddings.word_padding_idx
25
+ src_lens = kwargs["memory_lengths"]
26
+ src_max_len = self.state["src"].shape[0]
27
+ src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1)
28
+ if step is None:
29
+ tgt_max_len = tgt_lens.max()
30
+ tgt_pad_mask = ~sequence_mask(tgt_lens, tgt_max_len).unsqueeze(1)
31
+ else:
32
+ tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1)
33
+
34
+ with_align = kwargs.pop('with_align', False)
35
+ attn_aligns = []
36
+
37
+ for i, layer in enumerate(self.transformer_layers):
38
+ layer_cache = self.state["cache"]["layer_{}".format(i)] \
39
+ if step is not None else None
40
+ output, attn, attn_align = layer(
41
+ output,
42
+ src_memory_bank,
43
+ src_pad_mask,
44
+ tgt_pad_mask,
45
+ layer_cache=layer_cache,
46
+ step=step,
47
+ with_align=with_align)
48
+ if attn_align is not None:
49
+ attn_aligns.append(attn_align)
50
+
51
+ output = self.layer_norm(output)
52
+ dec_outs = output.transpose(0, 1).contiguous()
53
+ attn = attn.transpose(0, 1).contiguous()
54
+
55
+ attns = {"std": attn}
56
+ if self._copy:
57
+ attns["copy"] = attn
58
+ if with_align:
59
+ attns["align"] = attn_aligns[self.alignment_layer] # `(B, Q, K)`
60
+ # attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg
61
+
62
+ # TODO change the way attns is returned dict => list or tuple (onnx)
63
+ return dec_outs, attns
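A small sketch (not part of the committed file) of the two target-mask paths in `OnmtDecoder_1.forward`: full-sequence training builds the mask from `tgt_lengths`, while stepwise inference compares the target tokens against `pad_idx`. For a well-formed batch the two agree:

```python
import torch
from onmt_modules.misc import sequence_mask

pad_idx = 1
tgt_lengths = torch.tensor([5, 3])
tgt_words = torch.tensor([[7, 9, 4, 2, 2],
                          [6, 8, 2, pad_idx, pad_idx]])   # (batch, tgt_len)

# Training path (step is None): lengths -> padding mask.
train_mask = ~sequence_mask(tgt_lengths, tgt_words.size(1)).unsqueeze(1)

# Stepwise path (step is not None): token ids -> padding mask.
step_mask = tgt_words.eq(pad_idx).unsqueeze(1)

print(torch.equal(train_mask, step_mask))   # True for this example
```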
prepare_train_data.py ADDED
@@ -0,0 +1,90 @@
1
+ import os
2
+ import pickle
3
+ import numpy as np
4
+ import scipy.fftpack
5
+ import soundfile as sf
6
+ from utils import pySTFT
7
+ from scipy import signal
8
+ from librosa.filters import mel
9
+ from utils import butter_highpass
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from model_sea import Generator as Model
14
+ from hparams_sea import hparams
15
+
16
+
17
+ mel_basis = mel(16000, 1024, fmin=90, fmax=7600, n_mels=80).T
18
+ min_level = np.exp(-100 / 20 * np.log(10))
19
+ b, a = butter_highpass(30, 16000, order=5)
20
+
21
+ mfcc_mean, mfcc_std, dctmx = pickle.load(open('assets/mfcc_stats.pkl', 'rb'))
22
+ spk2emb = pickle.load(open('assets/spk2emb_82.pkl', 'rb'))
23
+
24
+ rootDir = "assets/vctk16-train-wav"
25
+ targetDir_sp = 'assets/vctk16-train-sp-mel'
26
+ targetDir_cep = 'assets/vctk16-train-cep-mel'
27
+ targetDir_cd = 'assets/vctk16-train-teacher'
28
+
29
+ device = 'cuda:0'
30
+
31
+ G = Model(hparams).eval().to(device)
32
+
33
+ g_checkpoint = torch.load('assets/sea.ckpt', map_location=lambda storage, loc: storage)
34
+ G.load_state_dict(g_checkpoint['model'], strict=True)
35
+
36
+
37
+ metadata = []
38
+ dirName, subdirList, _ = next(os.walk(rootDir))
39
+
40
+ for subdir in sorted(subdirList):
41
+ print(subdir)
42
+
43
+ if not os.path.exists(os.path.join(targetDir_sp, subdir)):
44
+ os.makedirs(os.path.join(targetDir_sp, subdir))
45
+ if not os.path.exists(os.path.join(targetDir_cep, subdir)):
46
+ os.makedirs(os.path.join(targetDir_cep, subdir))
47
+ if not os.path.exists(os.path.join(targetDir_cd, subdir)):
48
+ os.makedirs(os.path.join(targetDir_cd, subdir))
49
+
50
+ submeta = []
51
+ submeta.append(subdir)
52
+ submeta.append(spk2emb[subdir])
53
+
54
+ _,_, fileList = next(os.walk(os.path.join(dirName,subdir)))
55
+
56
+ for fileName in sorted(fileList):
57
+ x, fs = sf.read(os.path.join(dirName,subdir,fileName))
58
+ if x.shape[0] % 256 == 0:
59
+ x = np.concatenate((x, np.array([1e-06])), axis=0)
60
+ y = signal.filtfilt(b, a, x)
61
+ D = pySTFT(y * 0.96).T
62
+ D_mel = np.dot(D, mel_basis)
63
+ D_db = 20 * np.log10(np.maximum(min_level, D_mel))
64
+
65
+ # mel sp
66
+ S = (D_db + 80) / 100
67
+
68
+ # mel cep
69
+ cc_tmp = S.dot(dctmx)
70
+ cc_norm = (cc_tmp - mfcc_mean) / mfcc_std
71
+ S = np.clip(S, 0, 1)
72
+
73
+ # teacher code
74
+ cc_torch = torch.from_numpy(cc_norm[:,0:20].astype(np.float32)).unsqueeze(0).to(device)
75
+ with torch.no_grad():
76
+ codes = G.encode(cc_torch, torch.ones_like(cc_torch[:,:,0])).squeeze(0)
77
+
78
+ np.save(os.path.join(targetDir_cd, subdir, fileName[:-4]),
79
+ codes.cpu().numpy(), allow_pickle=False)
80
+ np.save(os.path.join(targetDir_sp, subdir, fileName[:-4]),
81
+ S.astype(np.float32), allow_pickle=False)
82
+ np.save(os.path.join(targetDir_cep, subdir, fileName[:-4]),
83
+ cc_norm.astype(np.float32), allow_pickle=False)
84
+
85
+ submeta.append(subdir+'/'+fileName[:-4]+'.npy')
86
+
87
+ metadata.append(submeta)
88
+
89
+ with open('./assets/train_vctk.meta', 'wb') as handle:
90
+ pickle.dump(metadata, handle)
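A condensed sketch of the feature pipeline above (not part of the committed file), run on a synthetic waveform instead of a VCTK file. It assumes the repo's `utils.pySTFT` / `utils.butter_highpass`, the `assets/mfcc_stats.pkl` file, and the older librosa API that the script itself targets.

```python
import pickle
import numpy as np
from scipy import signal
from librosa.filters import mel
from utils import pySTFT, butter_highpass

mel_basis = mel(16000, 1024, fmin=90, fmax=7600, n_mels=80).T
min_level = np.exp(-100 / 20 * np.log(10))
b, a = butter_highpass(30, 16000, order=5)
mfcc_mean, mfcc_std, dctmx = pickle.load(open('assets/mfcc_stats.pkl', 'rb'))

x = np.random.randn(16000) * 0.01                 # stand-in for sf.read(...)
if x.shape[0] % 256 == 0:                         # avoid an exact hop multiple
    x = np.concatenate((x, np.array([1e-06])), axis=0)

y = signal.filtfilt(b, a, x)                      # 30 Hz high-pass
D = pySTFT(y * 0.96).T                            # magnitude STFT (frames, bins)
D_db = 20 * np.log10(np.maximum(min_level, np.dot(D, mel_basis)))

S = (D_db + 80) / 100                             # normalised mel-spectrogram
cc_norm = (S.dot(dctmx) - mfcc_mean) / mfcc_std   # normalised mel-cepstrum
S = np.clip(S, 0, 1)
print(S.shape, cc_norm.shape)                     # (frames, 80) for both
```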
solver_1.py ADDED
@@ -0,0 +1,146 @@
1
+ import os
2
+ import time
3
+ import pickle
4
+ import datetime
5
+ import itertools
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ from onmt_modules.misc import sequence_mask
11
+ from model_autopst import Generator_1 as Predictor
12
+
13
+
14
+
15
+ class Solver(object):
16
+
17
+ def __init__(self, data_loader, config, hparams):
18
+ """Initialize configurations."""
19
+
20
+
21
+ self.data_loader = data_loader
22
+ self.hparams = hparams
23
+ self.gate_threshold = hparams.gate_threshold
24
+
25
+ self.use_cuda = torch.cuda.is_available()
26
+ self.device = torch.device('cuda:{}'.format(config.device_id) if self.use_cuda else 'cpu')
27
+ self.num_iters = config.num_iters
28
+ self.log_step = config.log_step
29
+
30
+ # Build the model
31
+ self.build_model()
32
+
33
+
34
+ def build_model(self):
35
+
36
+ self.P = Predictor(self.hparams)
37
+
38
+ self.optimizer = torch.optim.Adam(self.P.parameters(), 0.0001, [0.9, 0.999])
39
+
40
+ self.P.to(self.device)
41
+
42
+ self.BCELoss = torch.nn.BCEWithLogitsLoss().to(self.device)
43
+
44
+
45
+ def train(self):
46
+ # Set data loader
47
+ data_loader = self.data_loader
48
+ data_iter = iter(data_loader)
49
+
50
+
51
+ # Print logs in specified order
52
+ keys = ['P/loss_tx2sp', 'P/loss_stop_sp']
53
+
54
+
55
+ # Start training.
56
+ print('Start training...')
57
+ start_time = time.time()
58
+ for i in range(self.num_iters):
59
+
60
+ try:
61
+ sp_real, cep_real, cd_real, _, num_rep_sync, len_real, _, len_short_sync, spk_emb = next(data_iter)
62
+ except StopIteration:
63
+ data_iter = iter(data_loader)
64
+ sp_real, cep_real, cd_real, _, num_rep_sync, len_real, _, len_short_sync, spk_emb = next(data_iter)
65
+
66
+
67
+ sp_real = sp_real.to(self.device)
68
+ cep_real = cep_real.to(self.device)
69
+ cd_real = cd_real.to(self.device)
70
+ len_real = len_real.to(self.device)
71
+ spk_emb = spk_emb.to(self.device)
72
+ num_rep_sync = num_rep_sync.to(self.device)
73
+ len_short_sync = len_short_sync.to(self.device)
74
+
75
+
76
+ # real spect masks
77
+ mask_sp_real = ~sequence_mask(len_real, sp_real.size(1))
78
+ mask_long = (~mask_sp_real).float()
79
+
80
+ len_real_mask = torch.min(len_real + 10,
81
+ torch.full_like(len_real, sp_real.size(1)))
82
+ loss_tx2sp_mask = sequence_mask(len_real_mask, sp_real.size(1)).float().unsqueeze(-1)
83
+
84
+ # text input masks
85
+ codes_mask = sequence_mask(len_short_sync, num_rep_sync.size(1)).float()
86
+
87
+
88
+ # =================================================================================== #
89
+ # 2. Train #
90
+ # =================================================================================== #
91
+
92
+ self.P = self.P.train()
93
+
94
+
95
+ sp_real_sft = torch.zeros_like(sp_real)
96
+ sp_real_sft[:, 1:, :] = sp_real[:, :-1, :]
97
+
98
+
99
+ spect_pred, stop_pred_sp = self.P(cep_real.transpose(2,1),
100
+ mask_long,
101
+ codes_mask,
102
+ num_rep_sync,
103
+ len_short_sync+1,
104
+ sp_real_sft.transpose(1,0),
105
+ len_real+1,
106
+ spk_emb)
107
+
108
+
109
+ loss_tx2sp = (F.mse_loss(spect_pred.permute(1,0,2), sp_real, reduction='none')
110
+ * loss_tx2sp_mask).sum() / loss_tx2sp_mask.sum()
111
+
112
+ loss_stop_sp = self.BCELoss(stop_pred_sp.squeeze(-1).t(), mask_sp_real.float())
113
+
114
+ loss_total = loss_tx2sp + loss_stop_sp
115
+
116
+ # Backward and optimize
117
+ self.optimizer.zero_grad()
118
+ loss_total.backward()
119
+ self.optimizer.step()
120
+
121
+
122
+ # Logging
123
+ loss = {}
124
+ loss['P/loss_tx2sp'] = loss_tx2sp.item()
125
+ loss['P/loss_stop_sp'] = loss_stop_sp.item()
126
+
127
+
128
+ # =================================================================================== #
129
+ # 4. Miscellaneous #
130
+ # =================================================================================== #
131
+
132
+ # Print out training information
133
+ if (i+1) % self.log_step == 0:
134
+ et = time.time() - start_time
135
+ et = str(datetime.timedelta(seconds=et))[:-7]
136
+ log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
137
+ for tag in keys:
138
+ log += ", {}: {:.8f}".format(tag, loss[tag])
139
+ print(log)
140
+
141
+
142
+ # Save model checkpoints.
143
+ if (i+1) % 10000 == 0:
144
+ torch.save({'model': self.P.state_dict(),
145
+ 'optimizer': self.optimizer.state_dict()}, f'./assets/{i+1}-A.ckpt')
146
+ print('Saved model checkpoints into assets ...')
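A standalone sketch (not part of the committed file) of the masking logic in the training step above, on random tensors: the spectrogram loss is averaged only over real frames plus a 10-frame margin, and the stop-token loss is a BCE against the padding mask itself (written here with the functional equivalent of `BCEWithLogitsLoss`).

```python
import torch
import torch.nn.functional as F
from onmt_modules.misc import sequence_mask

B, T, n_mels = 2, 12, 80
len_real = torch.tensor([12, 8])
sp_real = torch.rand(B, T, n_mels)
spect_pred = torch.rand(T, B, n_mels)            # decoder output is time-major
stop_pred_sp = torch.randn(T, B, 1)

mask_sp_real = ~sequence_mask(len_real, T)       # True on padded frames
len_real_mask = torch.min(len_real + 10, torch.full_like(len_real, T))
loss_tx2sp_mask = sequence_mask(len_real_mask, T).float().unsqueeze(-1)

loss_tx2sp = (F.mse_loss(spect_pred.permute(1, 0, 2), sp_real, reduction='none')
              * loss_tx2sp_mask).sum() / loss_tx2sp_mask.sum()
loss_stop_sp = F.binary_cross_entropy_with_logits(
    stop_pred_sp.squeeze(-1).t(), mask_sp_real.float())
print(loss_tx2sp.item(), loss_stop_sp.item())
```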
solver_2.py ADDED
@@ -0,0 +1,160 @@
1
+ import os
2
+ import time
3
+ import pickle
4
+ import datetime
5
+ import itertools
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ from onmt_modules.misc import sequence_mask
11
+ from model_autopst import Generator_2 as Predictor
12
+
13
+
14
+
15
+ class Solver(object):
16
+
17
+ def __init__(self, data_loader, config, hparams):
18
+ """Initialize configurations."""
19
+
20
+
21
+ self.data_loader = data_loader
22
+ self.hparams = hparams
23
+ self.gate_threshold = hparams.gate_threshold
24
+
25
+ self.use_cuda = torch.cuda.is_available()
26
+ self.device = torch.device('cuda:{}'.format(config.device_id) if self.use_cuda else 'cpu')
27
+ self.num_iters = config.num_iters
28
+ self.log_step = config.log_step
29
+
30
+ # Build the model
31
+ self.build_model()
32
+
33
+
34
+ def build_model(self):
35
+
36
+ self.P = Predictor(self.hparams)
37
+ self.freeze_layers(self.P.encoder_cd)
38
+
39
+ self.optimizer = torch.optim.Adam(self.P.parameters(), 0.0001, [0.9, 0.999])
40
+
41
+ self.P.to(self.device)
42
+
43
+ self.BCELoss = torch.nn.BCEWithLogitsLoss().to(self.device)
44
+
45
+
46
+ checkpoint = torch.load(self.hparams.pretrained_path,
47
+ map_location=lambda storage, loc: storage)
48
+
49
+ self.P.load_state_dict(checkpoint['model'], strict=True)
50
+ print('Loaded pretrained encoder .........................................')
51
+
52
+
53
+ def freeze_layers(self, layer):
54
+ print('Fixing layers!')
55
+ for param in layer.parameters():
56
+ param.requires_grad = False
57
+
58
+
59
+ def train(self):
60
+ # Set data loader
61
+ data_loader = self.data_loader
62
+ data_iter = iter(data_loader)
63
+
64
+
65
+ # Print logs in specified order
66
+ keys = ['P/loss_tx2sp', 'P/loss_stop_sp']
67
+
68
+
69
+ # Start training.
70
+ print('Start training...')
71
+ start_time = time.time()
72
+ for i in range(self.num_iters):
73
+
74
+ try:
75
+ sp_real, cep_real, cd_real, num_rep, _, len_real, len_short, _, spk_emb = next(data_iter)
76
+ except StopIteration:
77
+ data_iter = iter(data_loader)
78
+ sp_real, cep_real, cd_real, num_rep, _, len_real, len_short, _, spk_emb = next(data_iter)
79
+
80
+
81
+ sp_real = sp_real.to(self.device)
82
+ cep_real = cep_real.to(self.device)
83
+ cd_real = cd_real.to(self.device)
84
+ len_real = len_real.to(self.device)
85
+ spk_emb = spk_emb.to(self.device)
86
+ num_rep = num_rep.to(self.device)
87
+ len_short = len_short.to(self.device)
88
+
89
+
90
+ # real spect masks
91
+ mask_sp_real = ~sequence_mask(len_real, sp_real.size(1))
92
+ mask_long = (~mask_sp_real).float()
93
+
94
+ len_real_mask = torch.min(len_real + 10,
95
+ torch.full_like(len_real, sp_real.size(1)))
96
+ loss_tx2sp_mask = sequence_mask(len_real_mask, sp_real.size(1)).float().unsqueeze(-1)
97
+
98
+ # text input masks
99
+ codes_mask = sequence_mask(len_short, num_rep.size(1)).float()
100
+
101
+
102
+ # =================================================================================== #
103
+ # 2. Train #
104
+ # =================================================================================== #
105
+
106
+ self.P = self.P.train()
107
+
108
+
109
+ sp_real_sft = torch.zeros_like(sp_real)
110
+ sp_real_sft[:, 1:, :] = sp_real[:, :-1, :]
111
+
112
+
113
+ spect_pred, stop_pred_sp = self.P(cep_real.transpose(2,1),
114
+ mask_long,
115
+ codes_mask,
116
+ num_rep,
117
+ len_short+1,
118
+ sp_real_sft.transpose(1,0),
119
+ len_real+1,
120
+ spk_emb)
121
+
122
+
123
+ loss_tx2sp = (F.mse_loss(spect_pred.permute(1,0,2), sp_real, reduction='none')
124
+ * loss_tx2sp_mask).sum() / loss_tx2sp_mask.sum()
125
+
126
+ loss_stop_sp = self.BCELoss(stop_pred_sp.squeeze(-1).t(), mask_sp_real.float())
127
+
128
+ loss_total = loss_tx2sp + loss_stop_sp
129
+
130
+ # Backward and optimize
131
+ self.optimizer.zero_grad()
132
+ loss_total.backward()
133
+ self.optimizer.step()
134
+
135
+
136
+ # Logging
137
+ loss = {}
138
+ loss['P/loss_tx2sp'] = loss_tx2sp.item()
139
+ loss['P/loss_stop_sp'] = loss_stop_sp.item()
140
+
141
+
142
+ # =================================================================================== #
143
+ # 4. Miscellaneous #
144
+ # =================================================================================== #
145
+
146
+ # Print out training information
147
+ if (i+1) % self.log_step == 0:
148
+ et = time.time() - start_time
149
+ et = str(datetime.timedelta(seconds=et))[:-7]
150
+ log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
151
+ for tag in keys:
152
+ log += ", {}: {:.8f}".format(tag, loss[tag])
153
+ print(log)
154
+
155
+
156
+ # Save model checkpoints.
157
+ if (i+1) % 10000 == 0:
158
+ torch.save({'model': self.P.state_dict(),
159
+ 'optimizer': self.optimizer.state_dict()}, f'./assets/{i+1}-B.ckpt')
160
+ print('Saved model checkpoints into assets ...')
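A toy sketch (not part of the committed file) of the stage-2 pattern used in `build_model` / `freeze_layers` above: load the stage-1 weights, freeze one sub-module, and keep optimizing the rest. `ToyPredictor` is a hypothetical stand-in for `Generator_2`; only the freezing mechanics are shown.

```python
import torch
import torch.nn as nn


class ToyPredictor(nn.Module):
    def __init__(self):
        super(ToyPredictor, self).__init__()
        self.encoder_cd = nn.Linear(8, 8)   # frozen, as in build_model()
        self.decoder = nn.Linear(8, 8)      # still trained


P = ToyPredictor()
# P.load_state_dict(torch.load(hparams.pretrained_path,
#                              map_location='cpu')['model'], strict=True)

for param in P.encoder_cd.parameters():     # freeze_layers(P.encoder_cd)
    param.requires_grad = False

optimizer = torch.optim.Adam(P.parameters(), 0.0001, [0.9, 0.999])

loss = P.decoder(P.encoder_cd(torch.randn(4, 8))).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Frozen parameters get no gradient, so the optimizer step leaves them alone.
print(P.encoder_cd.weight.grad is None)     # True
```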
tfcompat/__init__.py ADDED
File without changes
tfcompat/hparam.py ADDED
@@ -0,0 +1,726 @@
1
+ # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Hyperparameter values."""
16
+ from __future__ import absolute_import
17
+ from __future__ import division
18
+ from __future__ import print_function
19
+
20
+ import json
21
+ import numbers
22
+ import re
23
+
24
+ import six
25
+
26
+ ## from tensorflow.contrib.training.python.training import hparam_pb2
27
+ ## from tensorflow.python.framework import ops
28
+ ## from tensorflow.python.util import compat
29
+ ## from tensorflow.python.util import deprecation
30
+
31
+ # Define the regular expression for parsing a single clause of the input
32
+ # (delimited by commas). A legal clause looks like:
33
+ # <variable name>[<index>]? = <rhs>
34
+ # where <rhs> is either a single token or [] enclosed list of tokens.
35
+ # For example: "var[1] = a" or "x = [1,2,3]"
36
+ PARAM_RE = re.compile(r"""
37
+ (?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
38
+ (\[\s*(?P<index>\d+)\s*\])? # (optional) index: "1" or None
39
+ \s*=\s*
40
+ ((?P<val>[^,\[]*) # single value: "a" or None
41
+ |
42
+ \[(?P<vals>[^\]]*)\]) # list of values: None or "1,2,3"
43
+ ($|,\s*)""", re.VERBOSE)
44
+
45
+
46
+ def _parse_fail(name, var_type, value, values):
47
+ """Helper function for raising a value error for bad assignment."""
48
+ raise ValueError(
49
+ 'Could not parse hparam \'%s\' of type \'%s\' with value \'%s\' in %s' %
50
+ (name, var_type.__name__, value, values))
51
+
52
+
53
+ def _reuse_fail(name, values):
54
+ """Helper function for raising a value error for reuse of name."""
55
+ raise ValueError('Multiple assignments to variable \'%s\' in %s' % (name,
56
+ values))
57
+
58
+
59
+ def _process_scalar_value(name, parse_fn, var_type, m_dict, values,
60
+ results_dictionary):
61
+ """Update results_dictionary with a scalar value.
62
+
63
+ Used to update the results_dictionary to be returned by parse_values when
64
+ encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".)
65
+
66
+ Mutates results_dictionary.
67
+
68
+ Args:
69
+ name: Name of variable in assignment ("s" or "arr").
70
+ parse_fn: Function for parsing the actual value.
71
+ var_type: Type of named variable.
72
+ m_dict: Dictionary constructed from regex parsing.
73
+ m_dict['val']: RHS value (scalar)
74
+ m_dict['index']: List index value (or None)
75
+ values: Full expression being parsed
76
+ results_dictionary: The dictionary being updated for return by the parsing
77
+ function.
78
+
79
+ Raises:
80
+ ValueError: If the name has already been used.
81
+ """
82
+ try:
83
+ parsed_value = parse_fn(m_dict['val'])
84
+ except ValueError:
85
+ _parse_fail(name, var_type, m_dict['val'], values)
86
+
87
+ # If no index is provided
88
+ if not m_dict['index']:
89
+ if name in results_dictionary:
90
+ _reuse_fail(name, values)
91
+ results_dictionary[name] = parsed_value
92
+ else:
93
+ if name in results_dictionary:
94
+ # The name has already been used as a scalar, then it
95
+ # will be in this dictionary and map to a non-dictionary.
96
+ if not isinstance(results_dictionary.get(name), dict):
97
+ _reuse_fail(name, values)
98
+ else:
99
+ results_dictionary[name] = {}
100
+
101
+ index = int(m_dict['index'])
102
+ # Make sure the index position hasn't already been assigned a value.
103
+ if index in results_dictionary[name]:
104
+ _reuse_fail('{}[{}]'.format(name, index), values)
105
+ results_dictionary[name][index] = parsed_value
106
+
107
+
108
+ def _process_list_value(name, parse_fn, var_type, m_dict, values,
109
+ results_dictionary):
110
+ """Update results_dictionary from a list of values.
111
+
112
+ Used to update results_dictionary to be returned by parse_values when
113
+ encountering a clause with a list RHS (e.g. "arr=[1,2,3]".)
114
+
115
+ Mutates results_dictionary.
116
+
117
+ Args:
118
+ name: Name of variable in assignment ("arr").
119
+ parse_fn: Function for parsing individual values.
120
+ var_type: Type of named variable.
121
+ m_dict: Dictionary constructed from regex parsing.
122
+ m_dict['val']: RHS value (scalar)
123
+ values: Full expression being parsed
124
+ results_dictionary: The dictionary being updated for return by the parsing
125
+ function.
126
+
127
+ Raises:
128
+ ValueError: If the name has an index or the values cannot be parsed.
129
+ """
130
+ if m_dict['index'] is not None:
131
+ raise ValueError('Assignment of a list to a list index.')
132
+ elements = filter(None, re.split('[ ,]', m_dict['vals']))
133
+ # Make sure the name hasn't already been assigned a value
134
+ if name in results_dictionary:
135
+ raise _reuse_fail(name, values)
136
+ try:
137
+ results_dictionary[name] = [parse_fn(e) for e in elements]
138
+ except ValueError:
139
+ _parse_fail(name, var_type, m_dict['vals'], values)
140
+
141
+
142
+ def _cast_to_type_if_compatible(name, param_type, value):
143
+ """Cast hparam to the provided type, if compatible.
144
+
145
+ Args:
146
+ name: Name of the hparam to be cast.
147
+ param_type: The type of the hparam.
148
+ value: The value to be cast, if compatible.
149
+
150
+ Returns:
151
+ The result of casting `value` to `param_type`.
152
+
153
+ Raises:
154
+ ValueError: If the type of `value` is not compatible with param_type.
155
+ * If `param_type` is a string type, but `value` is not.
156
+ * If `param_type` is a boolean, but `value` is not, or vice versa.
157
+ * If `param_type` is an integer type, but `value` is not.
158
+ * If `param_type` is a float type, but `value` is not a numeric type.
159
+ """
160
+ fail_msg = (
161
+ "Could not cast hparam '%s' of type '%s' from value %r" %
162
+ (name, param_type, value))
163
+
164
+ # Some callers use None, for which we can't do any casting/checking. :(
165
+ if issubclass(param_type, type(None)):
166
+ return value
167
+
168
+ # Avoid converting a non-string type to a string.
169
+ if (issubclass(param_type, (six.string_types, six.binary_type)) and
170
+ not isinstance(value, (six.string_types, six.binary_type))):
171
+ raise ValueError(fail_msg)
172
+
173
+ # Avoid converting a number or string type to a boolean or vice versa.
174
+ if issubclass(param_type, bool) != isinstance(value, bool):
175
+ raise ValueError(fail_msg)
176
+
177
+ # Avoid converting float to an integer (the reverse is fine).
178
+ if (issubclass(param_type, numbers.Integral) and
179
+ not isinstance(value, numbers.Integral)):
180
+ raise ValueError(fail_msg)
181
+
182
+ # Avoid converting a non-numeric type to a numeric type.
183
+ if (issubclass(param_type, numbers.Number) and
184
+ not isinstance(value, numbers.Number)):
185
+ raise ValueError(fail_msg)
186
+
187
+ return param_type(value)
188
+
189
+
190
+ def parse_values(values, type_map):
191
+ """Parses hyperparameter values from a string into a python map.
192
+
193
+ `values` is a string containing comma-separated `name=value` pairs.
194
+ For each pair, the value of the hyperparameter named `name` is set to
195
+ `value`.
196
+
197
+ If a hyperparameter name appears multiple times in `values`, a ValueError
198
+ is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2').
199
+
200
+ If a hyperparameter name appears in both an index assignment and a scalar assignment,
201
+ a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1').
202
+
203
+ The hyperparameter name may contain '.' symbols, which will result in an
204
+ attribute name that is only accessible through the getattr and setattr
205
+ functions. (And must first be explicitly added through add_hparam.)
206
+
207
+ WARNING: Use of '.' in your variable names is allowed, but is not well
208
+ supported and not recommended.
209
+
210
+ The `value` in `name=value` must follow the syntax according to the
211
+ type of the parameter:
212
+
213
+ * Scalar integer: A Python-parsable integer value. E.g.: 1,
214
+ 100, -12.
215
+ * Scalar float: A Python-parsable floating point value. E.g.: 1.0,
216
+ -.54e89.
217
+ * Boolean: Either true or false.
218
+ * Scalar string: A non-empty sequence of characters, excluding comma,
219
+ spaces, and square brackets. E.g.: foo, bar_1.
220
+ * List: A comma separated list of scalar values of the parameter type
221
+ enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low].
222
+
223
+ When index assignment is used, the corresponding type_map key should be the
224
+ list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not
225
+ "arr[1]").
226
+
227
+ Args:
228
+ values: String. Comma separated list of `name=value` pairs where
229
+ 'value' must follow the syntax described above.
230
+ type_map: A dictionary mapping hyperparameter names to types. Note every
231
+ parameter name in values must be a key in type_map. The values must
232
+ conform to the types indicated, where a value V is said to conform to a
233
+ type T if either V has type T, or V is a list of elements of type T.
234
+ Hence, for a multidimensional parameter 'x' taking float values,
235
+ 'x=[0.1,0.2]' will parse successfully if type_map['x'] = float.
236
+
237
+ Returns:
238
+ A python map mapping each name to either:
239
+ * A scalar value.
240
+ * A list of scalar values.
241
+ * A dictionary mapping index numbers to scalar values.
242
+ (e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}")
243
+
244
+ Raises:
245
+ ValueError: If there is a problem with input.
246
+ * If `values` cannot be parsed.
247
+ * If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]').
248
+ * If the same rvalue is assigned two different values (e.g. 'a=1,a=2',
249
+ 'a[1]=1,a[1]=2', or 'a=1,a=[1]')
250
+ """
251
+ results_dictionary = {}
252
+ pos = 0
253
+ while pos < len(values):
254
+ m = PARAM_RE.match(values, pos)
255
+ if not m:
256
+ raise ValueError('Malformed hyperparameter value: %s' % values[pos:])
257
+ # Check that there is a comma between parameters and move past it.
258
+ pos = m.end()
259
+ # Parse the values.
260
+ m_dict = m.groupdict()
261
+ name = m_dict['name']
262
+ if name not in type_map:
263
+ raise ValueError('Unknown hyperparameter type for %s' % name)
264
+ type_ = type_map[name]
265
+
266
+ # Set up correct parsing function (depending on whether type_ is a bool)
267
+ if type_ == bool:
268
+
269
+ def parse_bool(value):
270
+ if value in ['true', 'True']:
271
+ return True
272
+ elif value in ['false', 'False']:
273
+ return False
274
+ else:
275
+ try:
276
+ return bool(int(value))
277
+ except ValueError:
278
+ _parse_fail(name, type_, value, values)
279
+
280
+ parse = parse_bool
281
+ else:
282
+ parse = type_
283
+
284
+ # If a single value is provided
285
+ if m_dict['val'] is not None:
286
+ _process_scalar_value(name, parse, type_, m_dict, values,
287
+ results_dictionary)
288
+
289
+ # If the assigned value is a list:
290
+ elif m_dict['vals'] is not None:
291
+ _process_list_value(name, parse, type_, m_dict, values,
292
+ results_dictionary)
293
+
294
+ else: # Not assigned a list or value
295
+ _parse_fail(name, type_, '', values)
296
+
297
+ return results_dictionary
298
+
299
+
300
+ class HParams(object):
301
+ """Class to hold a set of hyperparameters as name-value pairs.
302
+
303
+ A `HParams` object holds hyperparameters used to build and train a model,
304
+ such as the number of hidden units in a neural net layer or the learning rate
305
+ to use when training.
306
+
307
+ You first create a `HParams` object by specifying the names and values of the
308
+ hyperparameters.
309
+
310
+ To make them easily accessible the parameter names are added as direct
311
+ attributes of the class. A typical usage is as follows:
312
+
313
+ ```python
314
+ # Create a HParams object specifying names and values of the model
315
+ # hyperparameters:
316
+ hparams = HParams(learning_rate=0.1, num_hidden_units=100)
317
+
318
+ # The hyperparameters are available as attributes of the HParams object:
319
+ hparams.learning_rate ==> 0.1
320
+ hparams.num_hidden_units ==> 100
321
+ ```
322
+
323
+ Hyperparameters have type, which is inferred from the type of their value
324
+ passed at construction time. The currently supported types are: integer,
325
+ float, boolean, string, and list of integer, float, boolean, or string.
326
+
327
+ You can override hyperparameter values by calling the
328
+ [`parse()`](#HParams.parse) method, passing a string of comma separated
329
+ `name=value` pairs. This is intended to make it possible to override
330
+ any hyperparameter values from a single command-line flag to which
331
+ the user passes 'hyper-param=value' pairs. It avoids having to define
332
+ one flag for each hyperparameter.
333
+
334
+ The syntax expected for each value depends on the type of the parameter.
335
+ See `parse()` for a description of the syntax.
336
+
337
+ Example:
338
+
339
+ ```python
340
+ # Define a command line flag to pass name=value pairs.
341
+ # For example using argparse:
342
+ import argparse
343
+ parser = argparse.ArgumentParser(description='Train my model.')
344
+ parser.add_argument('--hparams', type=str,
345
+ help='Comma separated list of "name=value" pairs.')
346
+ args = parser.parse_args()
347
+ ...
348
+ def my_program():
349
+ # Create a HParams object specifying the names and values of the
350
+ # model hyperparameters:
351
+ hparams = tf.HParams(learning_rate=0.1, num_hidden_units=100,
352
+ activations=['relu', 'tanh'])
353
+
354
+ # Override hyperparameters values by parsing the command line
355
+ hparams.parse(args.hparams)
356
+
357
+ # If the user passed `--hparams=learning_rate=0.3` on the command line
358
+ # then 'hparams' has the following attributes:
359
+ hparams.learning_rate ==> 0.3
360
+ hparams.num_hidden_units ==> 100
361
+ hparams.activations ==> ['relu', 'tanh']
362
+
363
+ # If the hyperparameters are in json format use parse_json:
364
+ hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}')
365
+ ```
366
+ """
367
+
368
+ _HAS_DYNAMIC_ATTRIBUTES = True # Required for pytype checks.
369
+
370
+ def __init__(self, hparam_def=None, model_structure=None, **kwargs):
371
+ """Create an instance of `HParams` from keyword arguments.
372
+
373
+ The keyword arguments specify name-value pairs for the hyperparameters.
374
+ The parameter types are inferred from the type of the values passed.
375
+
376
+ The parameter names are added as attributes of `HParams` object, so they
377
+ can be accessed directly with the dot notation `hparams._name_`.
378
+
379
+ Example:
380
+
381
+ ```python
382
+ # Define 3 hyperparameters: 'learning_rate' is a float parameter,
383
+ # 'num_hidden_units' an integer parameter, and 'activation' a string
384
+ # parameter.
385
+ hparams = tf.HParams(
386
+ learning_rate=0.1, num_hidden_units=100, activation='relu')
387
+
388
+ hparams.activation ==> 'relu'
389
+ ```
390
+
391
+ Note that a few names are reserved and cannot be used as hyperparameter
392
+ names. If you use one of the reserved names, the constructor raises a
393
+ `ValueError`.
394
+
395
+ Args:
396
+ hparam_def: Serialized hyperparameters, encoded as a hparam_pb2.HParamDef
397
+ protocol buffer. If provided, this object is initialized by
398
+ deserializing hparam_def. Otherwise **kwargs is used.
399
+ model_structure: An instance of ModelStructure, defining the feature
400
+ crosses to be used in the Trial.
401
+ **kwargs: Key-value pairs where the key is the hyperparameter name and
402
+ the value is the value for the parameter.
403
+
404
+ Raises:
405
+ ValueError: If both `hparam_def` and initialization values are provided,
406
+ or if one of the arguments is invalid.
407
+
408
+ """
409
+ # Register the hyperparameters and their type in _hparam_types.
410
+ # This simplifies the implementation of parse().
411
+ # _hparam_types maps the parameter name to a tuple (type, bool).
412
+ # The type value is the type of the parameter for scalar hyperparameters,
413
+ # or the type of the list elements for multidimensional hyperparameters.
414
+ # The bool value is True if the value is a list, False otherwise.
415
+ self._hparam_types = {}
416
+ self._model_structure = model_structure
417
+ if hparam_def:
418
+ ## self._init_from_proto(hparam_def)
419
+ ## if kwargs:
420
+ ## raise ValueError('hparam_def and initialization values are '
421
+ ## 'mutually exclusive')
422
+ raise ValueError('hparam_def has been disabled in this version')
423
+ else:
424
+ for name, value in six.iteritems(kwargs):
425
+ self.add_hparam(name, value)
426
+
427
+ ## def _init_from_proto(self, hparam_def):
428
+ ## """Creates a new HParams from `HParamDef` protocol buffer.
429
+ ##
430
+ ## Args:
431
+ ## hparam_def: `HParamDef` protocol buffer.
432
+ ## """
433
+ ## assert isinstance(hparam_def, hparam_pb2.HParamDef)
434
+ ## for name, value in hparam_def.hparam.items():
435
+ ## kind = value.WhichOneof('kind')
436
+ ## if kind.endswith('_value'):
437
+ ## # Single value.
438
+ ## if kind.startswith('int64'):
439
+ ## # Setting attribute value to be 'int' to ensure the type is compatible
440
+ ## # with both Python2 and Python3.
441
+ ## self.add_hparam(name, int(getattr(value, kind)))
442
+ ## elif kind.startswith('bytes'):
443
+ ## # Setting attribute value to be 'str' to ensure the type is compatible
444
+ ## # with both Python2 and Python3. UTF-8 encoding is assumed.
445
+ ## self.add_hparam(name, compat.as_str(getattr(value, kind)))
446
+ ## else:
447
+ ## self.add_hparam(name, getattr(value, kind))
448
+ ## else:
449
+ ## # List of values.
450
+ ## if kind.startswith('int64'):
451
+ ## # Setting attribute value to be 'int' to ensure the type is compatible
452
+ ## # with both Python2 and Python3.
453
+ ## self.add_hparam(name, [int(v) for v in getattr(value, kind).value])
454
+ ## elif kind.startswith('bytes'):
455
+ ## # Setting attribute value to be 'str' to ensure the type is compatible
456
+ ## # with both Python2 and Python3. UTF-8 encoding is assumed.
457
+ ## self.add_hparam(
458
+ ## name, [compat.as_str(v) for v in getattr(value, kind).value])
459
+ ## else:
460
+ ## self.add_hparam(name, [v for v in getattr(value, kind).value])
461
+
462
+ def add_hparam(self, name, value):
463
+ """Adds {name, value} pair to hyperparameters.
464
+
465
+ Args:
466
+ name: Name of the hyperparameter.
467
+ value: Value of the hyperparameter. Can be one of the following types:
468
+ int, float, string, int list, float list, or string list.
469
+
470
+ Raises:
471
+ ValueError: if one of the arguments is invalid.
472
+ """
473
+ # Keys in kwargs are unique, but 'name' could be the name of a pre-existing
474
+ # attribute of this object. In that case we refuse to use it as a
475
+ # hyperparameter name.
476
+ if getattr(self, name, None) is not None:
477
+ raise ValueError('Hyperparameter name is reserved: %s' % name)
478
+ if isinstance(value, (list, tuple)):
479
+ if not value:
480
+ raise ValueError(
481
+ 'Multi-valued hyperparameters cannot be empty: %s' % name)
482
+ self._hparam_types[name] = (type(value[0]), True)
483
+ else:
484
+ self._hparam_types[name] = (type(value), False)
485
+ setattr(self, name, value)
486
+
487
+ def set_hparam(self, name, value):
488
+ """Set the value of an existing hyperparameter.
489
+
490
+ This function verifies that the type of the value matches the type of the
491
+ existing hyperparameter.
492
+
493
+ Args:
494
+ name: Name of the hyperparameter.
495
+ value: New value of the hyperparameter.
496
+
497
+ Raises:
498
+ ValueError: If there is a type mismatch.
499
+ """
500
+ param_type, is_list = self._hparam_types[name]
501
+ if isinstance(value, list):
502
+ if not is_list:
503
+ raise ValueError(
504
+ 'Must not pass a list for single-valued parameter: %s' % name)
505
+ setattr(self, name, [
506
+ _cast_to_type_if_compatible(name, param_type, v) for v in value])
507
+ else:
508
+ if is_list:
509
+ raise ValueError(
510
+ 'Must pass a list for multi-valued parameter: %s.' % name)
511
+ setattr(self, name, _cast_to_type_if_compatible(name, param_type, value))
512
+
513
+ def del_hparam(self, name):
514
+ """Removes the hyperparameter with key 'name'.
515
+
516
+ Args:
517
+ name: Name of the hyperparameter.
518
+ """
519
+ if hasattr(self, name):
520
+ delattr(self, name)
521
+ del self._hparam_types[name]
522
+
523
+ def parse(self, values):
524
+ """Override hyperparameter values, parsing new values from a string.
525
+
526
+ See parse_values for more detail on the allowed format for values.
527
+
528
+ Args:
529
+ values: String. Comma separated list of `name=value` pairs where
530
+ 'value' must follow the syntax described above.
531
+
532
+ Returns:
533
+ The `HParams` instance.
534
+
535
+ Raises:
536
+ ValueError: If `values` cannot be parsed.
537
+ """
538
+ type_map = dict()
539
+ for name, t in self._hparam_types.items():
540
+ param_type, _ = t
541
+ type_map[name] = param_type
542
+
543
+ values_map = parse_values(values, type_map)
544
+ return self.override_from_dict(values_map)
545
+
546
+ def override_from_dict(self, values_dict):
547
+ """Override hyperparameter values, parsing new values from a dictionary.
548
+
549
+ Args:
550
+ values_dict: Dictionary of name:value pairs.
551
+
552
+ Returns:
553
+ The `HParams` instance.
554
+
555
+ Raises:
556
+ ValueError: If `values_dict` cannot be parsed.
557
+ """
558
+ for name, value in values_dict.items():
559
+ self.set_hparam(name, value)
560
+ return self
561
+
562
+ ## @deprecation.deprecated(None, 'Use `override_from_dict`.')
563
+ def set_from_map(self, values_map):
564
+ """DEPRECATED. Use override_from_dict."""
565
+ return self.override_from_dict(values_dict=values_map)
566
+
567
+ def set_model_structure(self, model_structure):
568
+ self._model_structure = model_structure
569
+
570
+ def get_model_structure(self):
571
+ return self._model_structure
572
+
573
+ def to_json(self, indent=None, separators=None, sort_keys=False):
574
+ """Serializes the hyperparameters into JSON.
575
+
576
+ Args:
577
+ indent: If a non-negative integer, JSON array elements and object members
578
+ will be pretty-printed with that indent level. An indent level of 0, or
579
+ negative, will only insert newlines. `None` (the default) selects the
580
+ most compact representation.
581
+ separators: Optional `(item_separator, key_separator)` tuple. Default is
582
+ `(', ', ': ')`.
583
+ sort_keys: If `True`, the output dictionaries will be sorted by key.
584
+
585
+ Returns:
586
+ A JSON string.
587
+ """
588
+ return json.dumps(
589
+ self.values(),
590
+ indent=indent,
591
+ separators=separators,
592
+ sort_keys=sort_keys)
593
+
594
+ def parse_json(self, values_json):
595
+ """Override hyperparameter values, parsing new values from a json object.
596
+
597
+ Args:
598
+ values_json: String containing a json object of name:value pairs.
599
+
600
+ Returns:
601
+ The `HParams` instance.
602
+
603
+ Raises:
604
+ ValueError: If `values_json` cannot be parsed.
605
+ """
606
+ values_map = json.loads(values_json)
607
+ return self.override_from_dict(values_map)
608
+
609
+ def values(self):
610
+ """Return the hyperparameter values as a Python dictionary.
611
+
612
+ Returns:
613
+ A dictionary with hyperparameter names as keys. The values are the
614
+ hyperparameter values.
615
+ """
616
+ return {n: getattr(self, n) for n in self._hparam_types.keys()}
617
+
618
+ def get(self, key, default=None):
619
+ """Returns the value of `key` if it exists, else `default`."""
620
+ if key in self._hparam_types:
621
+ # Ensure that default is compatible with the parameter type.
622
+ if default is not None:
623
+ param_type, is_param_list = self._hparam_types[key]
624
+ type_str = 'list<%s>' % param_type if is_param_list else str(param_type)
625
+ fail_msg = ("Hparam '%s' of type '%s' is incompatible with "
626
+ 'default=%s' % (key, type_str, default))
627
+
628
+ is_default_list = isinstance(default, list)
629
+ if is_param_list != is_default_list:
630
+ raise ValueError(fail_msg)
631
+
632
+ try:
633
+ if is_default_list:
634
+ for value in default:
635
+ _cast_to_type_if_compatible(key, param_type, value)
636
+ else:
637
+ _cast_to_type_if_compatible(key, param_type, default)
638
+ except ValueError as e:
639
+ raise ValueError('%s. %s' % (fail_msg, e))
640
+
641
+ return getattr(self, key)
642
+
643
+ return default
644
+
645
+ def __contains__(self, key):
646
+ return key in self._hparam_types
647
+
648
+ def __str__(self):
649
+ return str(sorted(self.values().items()))
650
+
651
+ def __repr__(self):
652
+ return '%s(%s)' % (type(self).__name__, self.__str__())
653
+
654
+ @staticmethod
655
+ def _get_kind_name(param_type, is_list):
656
+ """Returns the field name given parameter type and is_list.
657
+
658
+ Args:
659
+ param_type: Data type of the hparam.
660
+ is_list: Whether this is a list.
661
+
662
+ Returns:
663
+ A string representation of the field name.
664
+
665
+ Raises:
666
+ ValueError: If parameter type is not recognized.
667
+ """
668
+ if issubclass(param_type, bool):
669
+ # This check must happen before issubclass(param_type, six.integer_types),
670
+ # since Python considers bool to be a subclass of int.
671
+ typename = 'bool'
672
+ elif issubclass(param_type, six.integer_types):
673
+ # Setting 'int' and 'long' types to be 'int64' to ensure the type is
674
+ # compatible with both Python2 and Python3.
675
+ typename = 'int64'
676
+ elif issubclass(param_type, (six.string_types, six.binary_type)):
677
+ # Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is
678
+ # compatible with both Python2 and Python3.
679
+ typename = 'bytes'
680
+ elif issubclass(param_type, float):
681
+ typename = 'float'
682
+ else:
683
+ raise ValueError('Unsupported parameter type: %s' % str(param_type))
684
+
685
+ suffix = 'list' if is_list else 'value'
686
+ return '_'.join([typename, suffix])
687
+
688
+ ## def to_proto(self, export_scope=None): # pylint: disable=unused-argument
689
+ ## """Converts a `HParams` object to a `HParamDef` protocol buffer.
690
+ ##
691
+ ## Args:
692
+ ## export_scope: Optional `string`. Name scope to remove.
693
+ ##
694
+ ## Returns:
695
+ ## A `HParamDef` protocol buffer.
696
+ ## """
697
+ ## hparam_proto = hparam_pb2.HParamDef()
698
+ ## for name in self._hparam_types:
699
+ ## # Parse the values.
700
+ ## param_type, is_list = self._hparam_types.get(name, (None, None))
701
+ ## kind = HParams._get_kind_name(param_type, is_list)
702
+ ##
703
+ ## if is_list:
704
+ ## if kind.startswith('bytes'):
705
+ ## v_list = [compat.as_bytes(v) for v in getattr(self, name)]
706
+ ## else:
707
+ ## v_list = [v for v in getattr(self, name)]
708
+ ## getattr(hparam_proto.hparam[name], kind).value.extend(v_list)
709
+ ## else:
710
+ ## v = getattr(self, name)
711
+ ## if kind.startswith('bytes'):
712
+ ## v = compat.as_bytes(getattr(self, name))
713
+ ## setattr(hparam_proto.hparam[name], kind, v)
714
+ ##
715
+ ## return hparam_proto
716
+
717
+ ## @staticmethod
718
+ ## def from_proto(hparam_def, import_scope=None): # pylint: disable=unused-argument
719
+ ## return HParams(hparam_def=hparam_def)
720
+
721
+
722
+ ## ops.register_proto_function(
723
+ ## 'hparams',
724
+ ## proto_type=hparam_pb2.HParamDef,
725
+ ## to_proto=HParams.to_proto,
726
+ ## from_proto=HParams.from_proto)
tfcompat/readme.md ADDED
@@ -0,0 +1,8 @@
1
+ Source: hparam.py copied from TensorFlow v1.12.0.
2
+
3
+ https://github.com/tensorflow/tensorflow/blob/v1.12.0/tensorflow/contrib/training/python/training/hparam.py
4
+
5
+ It was downloaded with the following command:
6
+ wget https://github.com/tensorflow/tensorflow/raw/v1.12.0/tensorflow/contrib/training/python/training/hparam.py
7
+
8
+ All other TensorFlow dependencies of this file have been removed while keeping the class's original behavior. Functions that became unavailable in this process are not used in this project.
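
As a quick sanity check of the kept functionality, here is a minimal usage sketch. The import path assumes this repo's layout (`tfcompat/hparam.py`), and the hyperparameter names below are purely illustrative, not taken from the project configuration:

```python
# Minimal sketch, assuming hparam.py lives at tfcompat/hparam.py in this repo.
from tfcompat.hparam import HParams

hparams = HParams(learning_rate=0.1, num_hidden_units=100,
                  activations=['relu', 'tanh'])

# Override values from a comma-separated, command-line style string.
hparams.parse('learning_rate=0.3,num_hidden_units=200')

print(hparams.learning_rate)   # 0.3
print(hparams.values())        # dict of all current hyperparameter values
```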
utils.py ADDED
@@ -0,0 +1,111 @@
1
+ import copy
2
+ import torch
3
+ import numpy as np
4
+ from scipy import signal
5
+ from librosa.filters import mel
6
+ from scipy.signal import get_window
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+ def butter_highpass(cutoff, fs, order=5):
13
+ nyq = 0.5 * fs
14
+ normal_cutoff = cutoff / nyq
15
+ b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
16
+ return b, a
17
+
18
+
19
+
20
+ def pySTFT(x, fft_length=1024, hop_length=256):
21
+
22
+ x = np.pad(x, int(fft_length//2), mode='reflect')
23
+
24
+ noverlap = fft_length - hop_length
25
+ shape = x.shape[:-1]+((x.shape[-1]-noverlap)//hop_length, fft_length)
26
+ strides = x.strides[:-1]+(hop_length*x.strides[-1], x.strides[-1])
27
+ result = np.lib.stride_tricks.as_strided(x, shape=shape,
28
+ strides=strides)
29
+
30
+ fft_window = get_window('hann', fft_length, fftbins=True)
31
+ result = np.fft.rfft(fft_window * result, n=fft_length).T
32
+
33
+ return np.abs(result)
34
+
35
+
36
+
37
+ class LinearNorm(torch.nn.Module):
38
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
39
+ super(LinearNorm, self).__init__()
40
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
41
+
42
+ torch.nn.init.xavier_uniform_(
43
+ self.linear_layer.weight,
44
+ gain=torch.nn.init.calculate_gain(w_init_gain))
45
+
46
+ def forward(self, x):
47
+ return self.linear_layer(x)
48
+
49
+
50
+ class ConvNorm(torch.nn.Module):
51
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
52
+ padding=None, dilation=1, bias=True, w_init_gain='linear'):
53
+ super(ConvNorm, self).__init__()
54
+ if padding is None:
55
+ assert(kernel_size % 2 == 1)
56
+ padding = int(dilation * (kernel_size - 1) / 2)
57
+
58
+ self.conv = torch.nn.Conv1d(in_channels, out_channels,
59
+ kernel_size=kernel_size, stride=stride,
60
+ padding=padding, dilation=dilation,
61
+ bias=bias)
62
+
63
+ torch.nn.init.xavier_uniform_(
64
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
65
+
66
+ def forward(self, signal):
67
+ conv_signal = self.conv(signal)
68
+ return conv_signal
69
+
70
+
71
+
72
+ def filter_bank_mean(num_rep, codes_mask, max_len_long):
73
+ '''
74
+ num_rep (B, L)
75
+ codes_mask (B, L)
76
+
77
+ output: filterbank (B, L, max_len_long)
78
+
79
+ zero padding in codes must be actual zeros
80
+ '''
81
+
82
+ num_rep = num_rep.unsqueeze(-1) # (B, L, 1)
83
+ codes_mask = codes_mask.unsqueeze(-1) # (B, L, 1)
84
+ num_rep = num_rep * codes_mask
85
+
86
+ right_edge = num_rep.cumsum(dim=1)
87
+ left_edge = torch.zeros_like(right_edge)
88
+ left_edge[:, 1:, :] = right_edge[:, :-1, :]
89
+ right_edge = right_edge.ceil()
90
+ left_edge = left_edge.floor()
91
+
92
+ index = torch.arange(1, max_len_long+1, device=num_rep.device).view(1, 1, -1)
93
+
94
+ lower = index - left_edge
95
+
96
+ right_edge_flip = max_len_long - right_edge
97
+
98
+ upper = (index - right_edge_flip).flip(dims=(2,))
99
+
100
+ # triangular pooling
101
+ fb = F.relu(torch.min(lower, upper)).float()
102
+
103
+ # mean pooling: binarize the triangular weights into a boxcar window
104
+ fb = (fb > 0).float()
105
+
106
+ norm = fb.sum(dim=-1, keepdim=True)
107
+ norm[norm==0] = 1.0
108
+
109
+ fb = fb / norm
110
+
111
+ return fb * codes_mask
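
A minimal sketch of how `filter_bank_mean` might be wired up: frame-level features of length `max_len_long` are mean-pooled down to the code sequence with a batched matrix product. The shapes follow the docstring above; the `torch.bmm` step and the concrete dimensions are assumptions for illustration, not code taken from this repo's training scripts:

```python
import torch
from utils import filter_bank_mean  # the function defined in utils.py above

# Illustrative shapes only; none of these values come from the project config.
B, L, T, D = 2, 5, 32, 80                      # batch, codes, frames, feature dim
num_rep = torch.full((B, L), float(T // L))    # frames spanned by each code
codes_mask = torch.ones(B, L)                  # 1 for real codes, 0 for padding
frames = torch.randn(B, T, D)                  # frame-level features

fb = filter_bank_mean(num_rep, codes_mask, max_len_long=T)   # (B, L, T)
codes = torch.bmm(fb, frames)                                # (B, L, D): mean of each code's frames
```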