Spaces:
Build error
Build error
taras-sereda
commited on
Commit
·
96ee597
1
Parent(s):
694ecc6
minimal set of files to run inference; pheme-small checkpoint
Browse files- ckpt/s2a/config.json +69 -0
- ckpt/s2a/s2a.ckpt +3 -0
- ckpt/t2s/config.json +29 -0
- ckpt/t2s/generation_config.json +7 -0
- ckpt/t2s/pytorch_model.bin +3 -0
- constants.py +14 -0
- data/collation.py +182 -0
- data/data_module.py +119 -0
- data/sampler.py +115 -0
- data/semantic_dataset.py +207 -0
- data/single_speaker_dataset.py +167 -0
- modules/__init__.py +0 -0
- modules/conformer.py +671 -0
- modules/masking_logic.py +111 -0
- modules/s2a_model.py +563 -0
- modules/speech_tokenizer.py +86 -0
- modules/t2s_model.py +111 -0
- modules/tokenizer.py +73 -0
- modules/vocoder.py +79 -0
- transformer_infer.py +263 -0
- utils/__init__.py +73 -0
- utils/get_tokens_speech_tokenizer.py +70 -0
- utils/symbol_table.py +281 -0
ckpt/s2a/config.json
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"saving_path": "/home/ubuntu/experiments/a2s_giga2",
|
3 |
+
"resume_checkpoint": null,
|
4 |
+
"vocoder_type": "SPEECHTOKENIZER",
|
5 |
+
"vocoder_config_path": null,
|
6 |
+
"vocoder_ckpt_path": null,
|
7 |
+
"metapath": [
|
8 |
+
"/home/ubuntu/data/poly/giga-training-data/train.json"
|
9 |
+
],
|
10 |
+
"val_metapath": [
|
11 |
+
"/home/ubuntu/data/poly/giga-training-data/dev.json"
|
12 |
+
],
|
13 |
+
"pretrained_path": null,
|
14 |
+
"speaker_embedding_dir": null,
|
15 |
+
"sampledir": "/home/ubuntu/experiments/a2s_giga2",
|
16 |
+
"lr": 0.0005,
|
17 |
+
"batch_size": 400.0,
|
18 |
+
"train_bucket_size": 8192,
|
19 |
+
"training_step": 800000,
|
20 |
+
"optim_flat_percent": 0.0,
|
21 |
+
"warmup_step": 10000,
|
22 |
+
"adam_beta1": 0.9,
|
23 |
+
"adam_beta2": 0.98,
|
24 |
+
"ffd_size": 1024,
|
25 |
+
"hidden_size": 768,
|
26 |
+
"enc_nlayers": 3,
|
27 |
+
"dec_nlayers": 6,
|
28 |
+
"nheads": 8,
|
29 |
+
"dropout": 0.1,
|
30 |
+
"depthwise_conv_kernel_size": 5,
|
31 |
+
"aligner_softmax_temp": 1.0,
|
32 |
+
"layer_norm_eps": 1e-05,
|
33 |
+
"use_sem_tokens": true,
|
34 |
+
"use_spkr_emb": true,
|
35 |
+
"use_text_emb": false,
|
36 |
+
"fairseq": false,
|
37 |
+
"only_inference": false,
|
38 |
+
"speaker_embed_dropout": 0.05,
|
39 |
+
"label_smoothing": 0.0,
|
40 |
+
"val_check_interval": 1,
|
41 |
+
"max_dataset_samples": -1,
|
42 |
+
"check_val_every_n_epoch": 1,
|
43 |
+
"precision": "bf16",
|
44 |
+
"nworkers": 12,
|
45 |
+
"distributed": true,
|
46 |
+
"accelerator": "gpu",
|
47 |
+
"version": null,
|
48 |
+
"accumulate_grad_batches": 1,
|
49 |
+
"sagemaker": false,
|
50 |
+
"use_repetition_token": false,
|
51 |
+
"use_repetition_gating": false,
|
52 |
+
"repetition_penalty": 1.0,
|
53 |
+
"sampling_temperature": 1.0,
|
54 |
+
"top_k": -1,
|
55 |
+
"min_top_k": 3,
|
56 |
+
"top_p": 0.8,
|
57 |
+
"sample_num": 4,
|
58 |
+
"length_penalty_max_length": 150,
|
59 |
+
"length_penalty_max_prob": 0.95,
|
60 |
+
"max_input_length": 2048,
|
61 |
+
"max_output_length": 2000,
|
62 |
+
"phone_context_window": 3,
|
63 |
+
"sample_rate": 16000,
|
64 |
+
"n_codes": 1024,
|
65 |
+
"n_cluster_groups": 7,
|
66 |
+
"first_n_lvls": 7,
|
67 |
+
"use_pretrained_ckpt_cfg": false,
|
68 |
+
"n_semantic_codes": 1024
|
69 |
+
}
|
ckpt/s2a/s2a.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f083b82b7a2e902cf318163562f6b3faa2d2b4ce72d20b92e3de8b8dfc125383
|
3 |
+
size 671800067
|
ckpt/t2s/config.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"T5ForConditionalGeneration"
|
4 |
+
],
|
5 |
+
"classifier_dropout": 0.0,
|
6 |
+
"d_ff": 2048,
|
7 |
+
"d_kv": 64,
|
8 |
+
"d_model": 512,
|
9 |
+
"decoder_start_token_id": 0,
|
10 |
+
"dense_act_fn": "relu",
|
11 |
+
"dropout_rate": 0.1,
|
12 |
+
"eos_token_id": 2,
|
13 |
+
"feed_forward_proj": "relu",
|
14 |
+
"initializer_factor": 1.0,
|
15 |
+
"is_encoder_decoder": true,
|
16 |
+
"is_gated_act": false,
|
17 |
+
"layer_norm_epsilon": 1e-06,
|
18 |
+
"model_type": "t5",
|
19 |
+
"num_decoder_layers": 6,
|
20 |
+
"num_heads": 8,
|
21 |
+
"num_layers": 6,
|
22 |
+
"pad_token_id": 0,
|
23 |
+
"relative_attention_max_distance": 128,
|
24 |
+
"relative_attention_num_buckets": 32,
|
25 |
+
"torch_dtype": "float32",
|
26 |
+
"transformers_version": "4.34.1",
|
27 |
+
"use_cache": true,
|
28 |
+
"vocab_size": 1119
|
29 |
+
}
|
ckpt/t2s/generation_config.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"decoder_start_token_id": 0,
|
4 |
+
"eos_token_id": 2,
|
5 |
+
"pad_token_id": 0,
|
6 |
+
"transformers_version": "4.34.1"
|
7 |
+
}
|
ckpt/t2s/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:00aab5d35aa54a7cc7c81c92c706928e33d82a5c7e63b3628df48b0aed28606f
|
3 |
+
size 178565209
|
constants.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Constants file.
|
2 |
+
|
3 |
+
Copyright PolyAI Limited.
|
4 |
+
"""
|
5 |
+
SPKR_EMB_SIZE = 512
|
6 |
+
|
7 |
+
PAD = 1024
|
8 |
+
|
9 |
+
SPKR_1 = 1025
|
10 |
+
SPKR_2 = 1026
|
11 |
+
|
12 |
+
BOS_TOKEN_ID = 0
|
13 |
+
PAD_TOKEN_ID = 0
|
14 |
+
EOS_TOKEN_ID = 2
|
data/collation.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Collators for T2S and S2A.
|
2 |
+
|
3 |
+
Copyright PolyAI Limited.
|
4 |
+
"""
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import List, Tuple, Union
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from utils.symbol_table import SymbolTable
|
12 |
+
|
13 |
+
|
14 |
+
class GlobalCollater:
|
15 |
+
def __init__(self, n_codes, n_semantic_codes):
|
16 |
+
self.n_codes = n_codes
|
17 |
+
self.sem_mask_id = n_semantic_codes
|
18 |
+
|
19 |
+
def collate(self, batch):
|
20 |
+
output = {
|
21 |
+
'speaker': [],
|
22 |
+
'tts_quantize_input': [],
|
23 |
+
'tts_quantize_output': [],
|
24 |
+
'quantize_mask': [],
|
25 |
+
'f_names': [],
|
26 |
+
'semantic_tokens': [],
|
27 |
+
'quantization_lengths': [],
|
28 |
+
}
|
29 |
+
# Get the max length of everything
|
30 |
+
max_len_q = 0
|
31 |
+
for _, q_s, q_e, _, _ in batch:
|
32 |
+
if len(q_s) > max_len_q:
|
33 |
+
max_len_q = len(q_s)
|
34 |
+
|
35 |
+
output['quantization_lengths'].append(len(q_s))
|
36 |
+
|
37 |
+
# Pad each element, create mask
|
38 |
+
for spkr, qs, qe, itm_name, s_tokens in batch:
|
39 |
+
# Deal with quantizations
|
40 |
+
q_mask = np.array(
|
41 |
+
[False] * len(qs) + [True] * (max_len_q - len(qs)))
|
42 |
+
qs = np.pad(
|
43 |
+
qs,
|
44 |
+
[[0, max_len_q-len(qs)], [0, 0]],
|
45 |
+
constant_values=self.n_codes
|
46 |
+
)
|
47 |
+
qe = np.pad(
|
48 |
+
qe,
|
49 |
+
[[0, max_len_q-len(qe)], [0, 0]],
|
50 |
+
constant_values=self.n_codes
|
51 |
+
)
|
52 |
+
|
53 |
+
# Deal with semantics
|
54 |
+
s_tokens = s_tokens.flatten()
|
55 |
+
s_tokens = np.pad(
|
56 |
+
s_tokens,
|
57 |
+
(0, max_len_q-len(s_tokens)),
|
58 |
+
constant_values=self.sem_mask_id
|
59 |
+
)
|
60 |
+
|
61 |
+
# Speaker padding
|
62 |
+
spkr = np.concatenate(
|
63 |
+
(spkr, np.zeros((max_len_q - len(spkr), 512))))
|
64 |
+
|
65 |
+
# Aggregate
|
66 |
+
output['speaker'].append(spkr)
|
67 |
+
output['tts_quantize_input'].append(qs)
|
68 |
+
output['tts_quantize_output'].append(qe)
|
69 |
+
output['quantize_mask'].append(q_mask)
|
70 |
+
output['f_names'].append(itm_name)
|
71 |
+
output["semantic_tokens"].append(s_tokens)
|
72 |
+
|
73 |
+
for k in output.keys():
|
74 |
+
if k == 'f_names':
|
75 |
+
continue
|
76 |
+
output[k] = np.array(output[k])
|
77 |
+
if 'mask' in k:
|
78 |
+
output[k] = torch.BoolTensor(output[k])
|
79 |
+
elif k in [
|
80 |
+
'tts_quantize_input', 'tts_quantize_output',
|
81 |
+
'semantic_tokens', 'quantization_lengths'
|
82 |
+
]:
|
83 |
+
output[k] = torch.LongTensor(output[k])
|
84 |
+
else:
|
85 |
+
output[k] = torch.FloatTensor(output[k])
|
86 |
+
return output
|
87 |
+
|
88 |
+
|
89 |
+
class TextTokenCollater:
|
90 |
+
def __init__(
|
91 |
+
self,
|
92 |
+
text_tokens: List[str],
|
93 |
+
add_eos: bool = True,
|
94 |
+
add_bos: bool = True,
|
95 |
+
pad_symbol: str = "<pad>",
|
96 |
+
bos_symbol: str = "<bos>",
|
97 |
+
eos_symbol: str = "<eos>",
|
98 |
+
spkr_1_symbol: str = "spkr_1",
|
99 |
+
spkr_2_symbol: str = "spkr_2",
|
100 |
+
):
|
101 |
+
self.pad_symbol = pad_symbol
|
102 |
+
|
103 |
+
self.add_eos = add_eos
|
104 |
+
self.add_bos = add_bos
|
105 |
+
|
106 |
+
self.bos_symbol = bos_symbol
|
107 |
+
self.eos_symbol = eos_symbol
|
108 |
+
self.spkr_1_symbol = spkr_1_symbol
|
109 |
+
self.spkr_2_symbol = spkr_2_symbol
|
110 |
+
|
111 |
+
unique_tokens = (
|
112 |
+
[pad_symbol]
|
113 |
+
+ ([bos_symbol] if add_bos else [])
|
114 |
+
+ ([eos_symbol] if add_eos else [])
|
115 |
+
+ ([spkr_1_symbol])
|
116 |
+
+ ([spkr_2_symbol])
|
117 |
+
+ sorted(text_tokens)
|
118 |
+
)
|
119 |
+
|
120 |
+
self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
|
121 |
+
self.idx2token = [token for token in unique_tokens]
|
122 |
+
|
123 |
+
def __call__(
|
124 |
+
self, texts: List[str], texts_2: Union[None, List[str]] = None
|
125 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
126 |
+
tokens_seqs = [[p for p in text] for text in texts]
|
127 |
+
|
128 |
+
if texts_2 is None:
|
129 |
+
seqs = [
|
130 |
+
([self.bos_symbol] if self.add_bos else [])
|
131 |
+
+ [self.spkr_1_symbol]
|
132 |
+
+ list(seq)
|
133 |
+
+ ([self.eos_symbol] if self.add_eos else [])
|
134 |
+
for seq in tokens_seqs
|
135 |
+
]
|
136 |
+
else:
|
137 |
+
tokens_seqs_2 = [[p for p in text] for text in texts_2]
|
138 |
+
seqs = [
|
139 |
+
([self.bos_symbol] if self.add_bos else [])
|
140 |
+
+ [self.spkr_1_symbol]
|
141 |
+
+ list(seq)
|
142 |
+
+ ([self.spkr_2_symbol])
|
143 |
+
+ list(seq_2)
|
144 |
+
+ ([self.eos_symbol] if self.add_eos else [])
|
145 |
+
for seq, seq_2 in zip(tokens_seqs, tokens_seqs_2)
|
146 |
+
]
|
147 |
+
|
148 |
+
tokens_batch = torch.from_numpy(
|
149 |
+
np.array(
|
150 |
+
[[self.token2idx[token] for token in seq] for seq in seqs],
|
151 |
+
dtype=np.int64,
|
152 |
+
)
|
153 |
+
)
|
154 |
+
|
155 |
+
return tokens_batch
|
156 |
+
|
157 |
+
|
158 |
+
def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater:
|
159 |
+
text_tokens_path = Path(text_tokens_file)
|
160 |
+
unique_tokens = SymbolTable.from_file(text_tokens_path)
|
161 |
+
collater = TextTokenCollater(
|
162 |
+
unique_tokens.symbols, add_bos=True, add_eos=True
|
163 |
+
)
|
164 |
+
return collater
|
165 |
+
|
166 |
+
|
167 |
+
def get_text_semantic_token_collater(
|
168 |
+
text_tokens_file: str, n_semantic_tokens=1024) -> TextTokenCollater:
|
169 |
+
text_tokens_path = Path(text_tokens_file)
|
170 |
+
unique_tokens = SymbolTable.from_file(text_tokens_path)
|
171 |
+
for semantic_idx in range(n_semantic_tokens):
|
172 |
+
unique_tokens.add(str(semantic_idx))
|
173 |
+
|
174 |
+
collater = TextTokenCollater(
|
175 |
+
unique_tokens.symbols, add_bos=True, add_eos=True
|
176 |
+
)
|
177 |
+
return collater
|
178 |
+
|
179 |
+
|
180 |
+
if __name__ == '__main__':
|
181 |
+
text_tokens_file = 'ckpt/unique_text_tokens.k2symbols'
|
182 |
+
collater = get_text_semantic_token_collater(text_tokens_file)
|
data/data_module.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Data module.
|
2 |
+
|
3 |
+
Copyright PolyAI Limited.
|
4 |
+
"""
|
5 |
+
import typing
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import List
|
8 |
+
|
9 |
+
import lightning.pytorch as pl
|
10 |
+
from torch.utils import data
|
11 |
+
|
12 |
+
from data.collation import GlobalCollater
|
13 |
+
from data.sampler import RandomBucketSampler
|
14 |
+
from data.single_speaker_dataset import QuantizeDataset
|
15 |
+
from utils import breakpoint_on_error
|
16 |
+
|
17 |
+
|
18 |
+
class ConcatDataset(data.ConcatDataset):
|
19 |
+
def __init__(self, datasets) -> None:
|
20 |
+
super().__init__(datasets)
|
21 |
+
self.lengths = []
|
22 |
+
for dataset in datasets:
|
23 |
+
self.lengths.extend(dataset.lengths)
|
24 |
+
|
25 |
+
|
26 |
+
class DataModule(pl.LightningDataModule):
|
27 |
+
def __init__(
|
28 |
+
self, hp, metapath: List[str], val_metapath: List[str],
|
29 |
+
world_size, local_rank
|
30 |
+
):
|
31 |
+
super().__init__()
|
32 |
+
self.hp = hp
|
33 |
+
self.metapath = metapath
|
34 |
+
self.val_metapath = val_metapath
|
35 |
+
self.world_size = world_size
|
36 |
+
self.local_rank = local_rank
|
37 |
+
self.collater = GlobalCollater(
|
38 |
+
self.hp.n_codes, self.hp.n_semantic_codes)
|
39 |
+
|
40 |
+
def setup(self, stage: str) -> None:
|
41 |
+
if stage == "fit":
|
42 |
+
self.train_data = self.concatenate_datasets(
|
43 |
+
self.metapath, dataset_class=QuantizeDataset
|
44 |
+
)
|
45 |
+
|
46 |
+
if stage == "valid":
|
47 |
+
self.val_data = []
|
48 |
+
self.val_data_keys = []
|
49 |
+
self.prepare_val_datasets()
|
50 |
+
assert len(self.val_data) > 0
|
51 |
+
assert len(self.val_data_keys) > 0
|
52 |
+
|
53 |
+
@breakpoint_on_error
|
54 |
+
def concatenate_datasets(
|
55 |
+
self, metapaths, dataset_class: typing.Type[QuantizeDataset]):
|
56 |
+
data = []
|
57 |
+
for _, metapath in enumerate(metapaths):
|
58 |
+
metapath = Path(metapath)
|
59 |
+
# assumption that audios and audios-embeddings
|
60 |
+
# are in the same folder as metapath
|
61 |
+
datadir = metapath.with_name("audios")
|
62 |
+
assert datadir.exists()
|
63 |
+
data.append(
|
64 |
+
dataset_class(
|
65 |
+
self.hp,
|
66 |
+
metapath,
|
67 |
+
datadir=datadir,
|
68 |
+
speaker_embedding_dir=None,
|
69 |
+
)
|
70 |
+
)
|
71 |
+
return ConcatDataset(data)
|
72 |
+
|
73 |
+
def prepare_val_datasets(self):
|
74 |
+
for manifest in self.val_metapath:
|
75 |
+
self.val_data.append(
|
76 |
+
self.concatenate_datasets(
|
77 |
+
[manifest], dataset_class=QuantizeDataset)
|
78 |
+
)
|
79 |
+
name = Path(manifest).parent.name
|
80 |
+
self.val_data_keys.append(name)
|
81 |
+
|
82 |
+
assert len(self.val_data) == len(self.val_data_keys)
|
83 |
+
|
84 |
+
def train_dataloader(self):
|
85 |
+
length = self.train_data.lengths
|
86 |
+
sampler = RandomBucketSampler(
|
87 |
+
self.hp.train_bucket_size,
|
88 |
+
length,
|
89 |
+
self.hp.batch_size,
|
90 |
+
drop_last=True,
|
91 |
+
distributed=self.hp.distributed,
|
92 |
+
world_size=self.world_size,
|
93 |
+
rank=self.local_rank,
|
94 |
+
)
|
95 |
+
dataloader = data.DataLoader(
|
96 |
+
self.train_data,
|
97 |
+
num_workers=self.hp.nworkers,
|
98 |
+
batch_sampler=sampler,
|
99 |
+
collate_fn=self.collater.collate,
|
100 |
+
pin_memory=True
|
101 |
+
)
|
102 |
+
|
103 |
+
return dataloader
|
104 |
+
|
105 |
+
def val_dataloader(self):
|
106 |
+
val_loaders = []
|
107 |
+
for dataset in self.val_data:
|
108 |
+
val_loaders.append(
|
109 |
+
data.DataLoader(
|
110 |
+
dataset,
|
111 |
+
num_workers=self.hp.nworkers,
|
112 |
+
batch_size=int(self.hp.batch_size),
|
113 |
+
collate_fn=self.collater.collate,
|
114 |
+
shuffle=False,
|
115 |
+
pin_memory=True
|
116 |
+
)
|
117 |
+
)
|
118 |
+
|
119 |
+
return val_loaders
|
data/sampler.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Original sampling logic of MQTTS.
|
2 |
+
|
3 |
+
Copyright PolyAI Limited.
|
4 |
+
"""
|
5 |
+
import math
|
6 |
+
import random
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
from torch.utils import data
|
10 |
+
|
11 |
+
|
12 |
+
def StandardSampler(dataset, shuffle, distributed=False,
|
13 |
+
world_size=None, rank=None):
|
14 |
+
if distributed:
|
15 |
+
return data.distributed.DistributedSampler(
|
16 |
+
dataset, shuffle=shuffle, num_replicas=world_size, rank=rank)
|
17 |
+
if shuffle:
|
18 |
+
return data.RandomSampler(dataset)
|
19 |
+
return data.SequentialSampler(dataset)
|
20 |
+
|
21 |
+
|
22 |
+
def RandomBucketSampler(
|
23 |
+
nbuckets, length, batch_size, drop_last, distributed=False,
|
24 |
+
world_size=None, rank=None):
|
25 |
+
if distributed:
|
26 |
+
return DistributedRandomBucketSampler(
|
27 |
+
nbuckets, length, batch_size, drop_last, world_size, rank)
|
28 |
+
return SingleRandomBucketSampler(nbuckets, length, batch_size, drop_last)
|
29 |
+
|
30 |
+
|
31 |
+
class SingleRandomBucketSampler(data.Sampler):
|
32 |
+
def __init__(self, nbuckets, length, batch_size, drop_last):
|
33 |
+
self.length = length
|
34 |
+
self.batch_size = batch_size
|
35 |
+
self.drop_last = drop_last
|
36 |
+
indices = np.argsort([-x for x in length])
|
37 |
+
split = len(indices) // nbuckets
|
38 |
+
self.indices = []
|
39 |
+
for i in range(nbuckets):
|
40 |
+
self.indices.append(indices[i*split:(i+1)*split])
|
41 |
+
if nbuckets * split < len(length):
|
42 |
+
self.indices.append(indices[nbuckets*split:])
|
43 |
+
|
44 |
+
def __iter__(self):
|
45 |
+
random.shuffle(self.indices)
|
46 |
+
for x in self.indices:
|
47 |
+
random.shuffle(x)
|
48 |
+
idxs = [i for x in self.indices for i in x]
|
49 |
+
batches, batch, sum_len, max_len = [], [], 0, 0
|
50 |
+
for idx in idxs:
|
51 |
+
batch.append(idx)
|
52 |
+
sum_len += self.length[idx]
|
53 |
+
max_len = max(self.length[idx], max_len)
|
54 |
+
if max_len * len(batch) > self.batch_size:
|
55 |
+
batches.append(batch[:-1])
|
56 |
+
batch, sum_len, max_len = [batch[-1]], self.length[idx], self.length[idx] # noqa
|
57 |
+
if len(batch) > 0 and not self.drop_last:
|
58 |
+
batches.append(batch)
|
59 |
+
random.shuffle(batches)
|
60 |
+
return iter(batches)
|
61 |
+
|
62 |
+
|
63 |
+
class DistributedRandomBucketSampler(data.Sampler):
|
64 |
+
def __init__(self, nbuckets, length, batch_size,
|
65 |
+
drop_last, num_replicas, rank, seed=1234):
|
66 |
+
if rank >= num_replicas or rank < 0:
|
67 |
+
raise ValueError(
|
68 |
+
"Invalid rank {}, rank should be in the interval"
|
69 |
+
" [0, {}]".format(rank, num_replicas - 1))
|
70 |
+
indices = np.argsort(length)
|
71 |
+
split = len(indices) // nbuckets
|
72 |
+
self.length = length
|
73 |
+
self.batch_size = batch_size
|
74 |
+
self.drop_last = drop_last
|
75 |
+
self.indices = []
|
76 |
+
for i in range(nbuckets):
|
77 |
+
self.indices.append(indices[i*split:(i+1)*split])
|
78 |
+
if nbuckets * split < len(length):
|
79 |
+
self.indices.append(indices[nbuckets*split:])
|
80 |
+
self.num_replicas = num_replicas
|
81 |
+
self.rank = rank
|
82 |
+
self.epoch = 0
|
83 |
+
self.seed = seed
|
84 |
+
|
85 |
+
def __iter__(self):
|
86 |
+
# Deterministic shuffling
|
87 |
+
random.Random(self.epoch + self.seed).shuffle(self.indices)
|
88 |
+
for i, x in enumerate(self.indices):
|
89 |
+
seed = self.epoch + self.seed + i * 5
|
90 |
+
random.Random(seed).shuffle(x)
|
91 |
+
indices = [i for x in self.indices for i in x]
|
92 |
+
|
93 |
+
# Batching
|
94 |
+
batches, batch, sum_len, max_len = [], [], 0, 0
|
95 |
+
for idx in indices:
|
96 |
+
batch.append(idx)
|
97 |
+
sum_len += self.length[idx]
|
98 |
+
max_len = max(self.length[idx], max_len)
|
99 |
+
if max_len * len(batch) > self.batch_size:
|
100 |
+
batches.append(batch[:-1])
|
101 |
+
batch, sum_len, max_len = [batch[-1]], self.length[idx], self.length[idx] # noqa
|
102 |
+
# Subsample
|
103 |
+
num_samples = math.ceil(
|
104 |
+
(len(batches) - self.num_replicas) / self.num_replicas)
|
105 |
+
total_size = num_samples * self.num_replicas
|
106 |
+
batches = batches[:total_size]
|
107 |
+
batches = batches[self.rank*num_samples: (self.rank+1)*num_samples]
|
108 |
+
assert len(batches) == num_samples
|
109 |
+
|
110 |
+
# Stochastic suffling
|
111 |
+
random.shuffle(batches)
|
112 |
+
return iter(batches)
|
113 |
+
|
114 |
+
def set_epoch(self, epoch):
|
115 |
+
self.epoch = epoch
|
data/semantic_dataset.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Semantic tokens loading logic.
|
2 |
+
|
3 |
+
Copyright PolyAI Limited.
|
4 |
+
"""
|
5 |
+
import json
|
6 |
+
import logging
|
7 |
+
import random
|
8 |
+
import re
|
9 |
+
from logging import getLogger
|
10 |
+
from pathlib import Path
|
11 |
+
from typing import List, Pattern, Union
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
from phonemizer.backend import EspeakBackend
|
16 |
+
from phonemizer.backend.espeak.language_switch import LanguageSwitch
|
17 |
+
from phonemizer.backend.espeak.words_mismatch import WordMismatch
|
18 |
+
from phonemizer.punctuation import Punctuation
|
19 |
+
from phonemizer.separator import Separator
|
20 |
+
from torch.utils.data import DataLoader, Dataset
|
21 |
+
from tqdm import tqdm
|
22 |
+
|
23 |
+
from data.collation import get_text_semantic_token_collater
|
24 |
+
|
25 |
+
|
26 |
+
class TextTokenizer:
|
27 |
+
"""Phonemize Text."""
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
language="en-us",
|
32 |
+
backend="espeak",
|
33 |
+
separator=Separator(word="_", syllable="-", phone="|"),
|
34 |
+
preserve_punctuation=True,
|
35 |
+
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
|
36 |
+
with_stress: bool = False,
|
37 |
+
tie: Union[bool, str] = False,
|
38 |
+
language_switch: LanguageSwitch = "keep-flags",
|
39 |
+
words_mismatch: WordMismatch = "ignore",
|
40 |
+
) -> None:
|
41 |
+
logger = getLogger("phonemizer")
|
42 |
+
logger.setLevel(logging.ERROR)
|
43 |
+
if backend == "espeak":
|
44 |
+
phonemizer = EspeakBackend(
|
45 |
+
language,
|
46 |
+
punctuation_marks=punctuation_marks,
|
47 |
+
preserve_punctuation=preserve_punctuation,
|
48 |
+
with_stress=with_stress,
|
49 |
+
tie=tie,
|
50 |
+
language_switch=language_switch,
|
51 |
+
words_mismatch=words_mismatch,
|
52 |
+
logger=logger,
|
53 |
+
)
|
54 |
+
else:
|
55 |
+
raise NotImplementedError(f"{backend}")
|
56 |
+
|
57 |
+
self.backend = phonemizer
|
58 |
+
self.separator = separator
|
59 |
+
|
60 |
+
def to_list(self, phonemized: str) -> List[str]:
|
61 |
+
fields = []
|
62 |
+
for word in phonemized.split(self.separator.word):
|
63 |
+
# "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
|
64 |
+
pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
|
65 |
+
fields.extend(
|
66 |
+
[p for p in pp if p != self.separator.phone] + [self.separator.word]
|
67 |
+
)
|
68 |
+
assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
|
69 |
+
self.separator.phone
|
70 |
+
)
|
71 |
+
return fields[:-1]
|
72 |
+
|
73 |
+
def __call__(self, text, strip=True) -> List[List[str]]:
|
74 |
+
if isinstance(text, str):
|
75 |
+
text = [text]
|
76 |
+
|
77 |
+
phonemized = self.backend.phonemize(
|
78 |
+
text, separator=self.separator, strip=strip, njobs=1
|
79 |
+
)
|
80 |
+
return [self.to_list(p) for p in phonemized]
|
81 |
+
|
82 |
+
|
83 |
+
class Collator:
|
84 |
+
def collate(self, batch):
|
85 |
+
input_ids = [item["input_ids"] for item in batch]
|
86 |
+
output_sequences = [item["labels"] for item in batch]
|
87 |
+
|
88 |
+
# Pad sequences to the maximum length in the batch
|
89 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(
|
90 |
+
input_ids, batch_first=True, padding_value=0
|
91 |
+
)
|
92 |
+
output_sequences = torch.nn.utils.rnn.pad_sequence(
|
93 |
+
output_sequences, batch_first=True, padding_value=-100
|
94 |
+
)
|
95 |
+
# 1 - token is unmasked, 0 - token is masked.
|
96 |
+
attention_mask = input_ids != 0
|
97 |
+
return {
|
98 |
+
"input_ids": input_ids,
|
99 |
+
"attention_mask": attention_mask,
|
100 |
+
"labels": output_sequences,
|
101 |
+
}
|
102 |
+
|
103 |
+
class ConcatenateSemanticDataset(Dataset):
|
104 |
+
def __init__(
|
105 |
+
self, manifest_path: str, symbol_table_path: str,
|
106 |
+
n_samples: int = 0, max_duration=15):
|
107 |
+
self.data = []
|
108 |
+
self.phonemizer = TextTokenizer()
|
109 |
+
self.text_collater = get_text_semantic_token_collater(
|
110 |
+
symbol_table_path)
|
111 |
+
self.manifest_path = manifest_path
|
112 |
+
self.n_samples = n_samples
|
113 |
+
self.max_duration = max_duration
|
114 |
+
if manifest_path is not None:
|
115 |
+
self._build()
|
116 |
+
|
117 |
+
def __len__(self):
|
118 |
+
if self.n_samples:
|
119 |
+
return min(self.n_samples, len(self.data))
|
120 |
+
return len(self.data)
|
121 |
+
|
122 |
+
def remove_unknown_symbols(self, text: List[str]):
|
123 |
+
res = []
|
124 |
+
for sym in text:
|
125 |
+
if sym not in self.text_collater.token2idx:
|
126 |
+
# print(f'{sym} is unk')
|
127 |
+
continue
|
128 |
+
res.append(sym)
|
129 |
+
return res
|
130 |
+
|
131 |
+
def __getitem__(self, idx):
|
132 |
+
item = self.data[idx]
|
133 |
+
|
134 |
+
input_ids = item["phoneme"].split("|")
|
135 |
+
input_ids = self.remove_unknown_symbols(input_ids)
|
136 |
+
|
137 |
+
input_ids_2 = None
|
138 |
+
if item.get("phoneme_2"):
|
139 |
+
input_ids_2 = item["phoneme_2"].split("|")
|
140 |
+
input_ids_2 = [self.remove_unknown_symbols(input_ids_2)]
|
141 |
+
|
142 |
+
input_ids = self.text_collater(
|
143 |
+
[input_ids], input_ids_2).to(dtype=torch.long)
|
144 |
+
input_ids = input_ids.to(dtype=torch.long)
|
145 |
+
|
146 |
+
labels = np.load(item["semantic_path"])
|
147 |
+
labels = [str(lbl) for lbl in labels]
|
148 |
+
|
149 |
+
labels_2 = None
|
150 |
+
if item.get("semantic_path_2"):
|
151 |
+
labels_2 = np.load(item["semantic_path_2"])
|
152 |
+
labels_2 = [[str(lbl) for lbl in labels_2]]
|
153 |
+
|
154 |
+
labels = self.text_collater([labels], labels_2).to(dtype=torch.long)
|
155 |
+
|
156 |
+
return {"input_ids": input_ids.squeeze(0), "labels": labels.squeeze(0)}
|
157 |
+
|
158 |
+
# TODO - remove this to not load to the memory
|
159 |
+
def _build(self):
|
160 |
+
for manifest_path in self.manifest_path:
|
161 |
+
dataset_path = Path(manifest_path).parent
|
162 |
+
|
163 |
+
with open(manifest_path, "r") as manifest_file:
|
164 |
+
manifest_data = json.load(manifest_file)
|
165 |
+
|
166 |
+
for key, value in tqdm(manifest_data.items()):
|
167 |
+
if float(value["duration"]) > self.max_duration:
|
168 |
+
continue
|
169 |
+
text = value["text"]
|
170 |
+
phoneme = value["phoneme"]
|
171 |
+
npy_path = f"{dataset_path}/audios-speech-tokenizer/semantic/{key.split('.wav')[0]}.npy" # noqa
|
172 |
+
datapoint = {
|
173 |
+
"text": text,
|
174 |
+
"semantic_path": npy_path,
|
175 |
+
"phoneme": phoneme
|
176 |
+
}
|
177 |
+
self.data.append(datapoint)
|
178 |
+
|
179 |
+
print(f"Total length of the dataset {manifest_path}: {len(self.data)}")
|
180 |
+
|
181 |
+
random.shuffle(self.data)
|
182 |
+
|
183 |
+
|
184 |
+
if __name__ == "__main__":
|
185 |
+
# Create an instance of the dataset
|
186 |
+
manifest_path = "datasets/ljspeech-training-data/dev.json"
|
187 |
+
text_tokens_file = "ckpt/unique_text_tokens.k2symbols"
|
188 |
+
seq2seq_dataset = ConcatenateSemanticDataset(
|
189 |
+
[manifest_path, manifest_path], text_tokens_file)
|
190 |
+
|
191 |
+
# seq2seq_dataset.phonemize_and_rewrite_manifest()
|
192 |
+
batch_size = 1 # Adjust to your desired batch size
|
193 |
+
dataloader = DataLoader(
|
194 |
+
seq2seq_dataset,
|
195 |
+
batch_size=batch_size,
|
196 |
+
shuffle=True,
|
197 |
+
collate_fn=Collator().collate,
|
198 |
+
)
|
199 |
+
|
200 |
+
for batch in dataloader:
|
201 |
+
print(batch["input_ids"])
|
202 |
+
print(batch["labels"])
|
203 |
+
print(batch["input_ids"][0].unique().max())
|
204 |
+
print(batch["input_ids"][0].unique().min())
|
205 |
+
print(batch["input_ids"].shape)
|
206 |
+
print(batch["labels"].shape)
|
207 |
+
break # Stop after the first batch if needed
|
data/single_speaker_dataset.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Main loading function.
|
2 |
+
|
3 |
+
Copyright PolyAI Limited.
|
4 |
+
"""
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import soundfile as sf
|
12 |
+
import torch
|
13 |
+
from librosa.util import normalize
|
14 |
+
from pyannote.audio import Inference
|
15 |
+
from torch.utils import data
|
16 |
+
|
17 |
+
import constants as c
|
18 |
+
|
19 |
+
|
20 |
+
def random_crop(x, maxseqlen):
|
21 |
+
if x.shape[0] >= maxseqlen:
|
22 |
+
offset = random.randrange(x.shape[0] - maxseqlen + 1)
|
23 |
+
x = x[offset: offset + maxseqlen]
|
24 |
+
else:
|
25 |
+
offset = 0
|
26 |
+
return x, offset
|
27 |
+
|
28 |
+
|
29 |
+
def dynamic_range_compression(x, C=0.3, M=6.5, clip_val=1e-5):
|
30 |
+
return (np.log(np.clip(x, a_min=clip_val, a_max=None)) + M) * C
|
31 |
+
|
32 |
+
|
33 |
+
def dynamic_range_decompression(x, C=0.3, M=6.5):
|
34 |
+
return np.exp(x / C - M)
|
35 |
+
|
36 |
+
|
37 |
+
class QuantizeDataset(data.Dataset):
|
38 |
+
def __init__(self, hp, metapath, datadir=None, speaker_embedding_dir=None):
|
39 |
+
self.hp = hp
|
40 |
+
self.datadir = Path(datadir)
|
41 |
+
self.speaker_embedding_dir = speaker_embedding_dir
|
42 |
+
self.sem_mask_id = hp.n_semantic_codes
|
43 |
+
|
44 |
+
print(f"Loading metadata in {metapath}...")
|
45 |
+
with open(metapath, "r") as f:
|
46 |
+
self.text = json.load(f)
|
47 |
+
if 0 < self.hp.max_dataset_samples < len(self.text):
|
48 |
+
self.new_text = {}
|
49 |
+
num = 0
|
50 |
+
for k, v in self.text.items():
|
51 |
+
if num >= self.hp.max_dataset_samples:
|
52 |
+
break
|
53 |
+
self.new_text[k] = v
|
54 |
+
num += 1
|
55 |
+
self.text = self.new_text
|
56 |
+
|
57 |
+
self.datasetbase = [x for x in self.text.keys()]
|
58 |
+
self.dataset = [
|
59 |
+
os.path.join(self.datadir, x) for x in self.datasetbase]
|
60 |
+
|
61 |
+
if self.speaker_embedding_dir is None:
|
62 |
+
self.spkr_embedding = Inference(
|
63 |
+
"pyannote/embedding",
|
64 |
+
window="whole",
|
65 |
+
use_auth_token=os.environ["HUGGING_FACE_HUB_TOKEN"],
|
66 |
+
)
|
67 |
+
|
68 |
+
# Print statistics:
|
69 |
+
n = len(self.dataset)
|
70 |
+
print(f"Total {n} examples")
|
71 |
+
|
72 |
+
self.lengths = [float(v["duration"]) for v in self.text.values()]
|
73 |
+
total_duration = sum(self.lengths)
|
74 |
+
avglen = total_duration / len(self.lengths)
|
75 |
+
maxlen = max(self.lengths)
|
76 |
+
minlen = min(self.lengths)
|
77 |
+
print(
|
78 |
+
f"Average duration of audio: {avglen} sec, "
|
79 |
+
"Maximum duration: {maxlen} sec, Minimum duration: {minlen} sec"
|
80 |
+
)
|
81 |
+
|
82 |
+
def __len__(self):
|
83 |
+
return len(self.dataset)
|
84 |
+
|
85 |
+
def load_quantization(self, _name):
|
86 |
+
if self.hp.vocoder_type == 'NATIVE':
|
87 |
+
metadata = self.text[_name]
|
88 |
+
quantization = np.array(metadata["quantization"]).T # ..., 4
|
89 |
+
elif self.hp.vocoder_type == 'DAC':
|
90 |
+
codes_path = self.datadir.parent / 'audios-dac' / (os.path.splitext(_name)[0] + ".npy") # noqa
|
91 |
+
quantization = np.load(codes_path).T # ..., 12
|
92 |
+
elif self.hp.vocoder_type == 'ENCODEC':
|
93 |
+
codes_path = self.datadir.parent / 'audios-encodec' / (os.path.splitext(_name)[0] + ".npy") # noqa
|
94 |
+
quantization = np.load(codes_path).squeeze(0).T # ..., 8
|
95 |
+
elif self.hp.vocoder_type == 'SPEECHTOKENIZER':
|
96 |
+
codes_path = self.datadir.parent / 'audios-speech-tokenizer/acoustic' / (os.path.splitext(_name)[0] + ".npy") # noqa
|
97 |
+
quantization = np.load(codes_path).T # ..., 7
|
98 |
+
else:
|
99 |
+
raise ValueError(f"Unknown vocoder_type {self.hp.vocoder_type}")
|
100 |
+
|
101 |
+
return quantization
|
102 |
+
|
103 |
+
def __getitem__(self, i):
|
104 |
+
dataname = self.dataset[i]
|
105 |
+
_name = self.datasetbase[i]
|
106 |
+
metadata = self.text[_name]
|
107 |
+
|
108 |
+
# Speaker 1
|
109 |
+
acoustic_tokens = self.load_quantization(_name)
|
110 |
+
acoustic_tokens = np.pad(
|
111 |
+
acoustic_tokens, [[1, 0],[0,0]], constant_values=c.SPKR_1)
|
112 |
+
|
113 |
+
npy_path = self.datadir.parent / 'audios-speech-tokenizer/semantic' / (os.path.splitext(_name)[0] + ".npy") # noqa
|
114 |
+
semantic_tokens = np.load(npy_path)[None]
|
115 |
+
semantic_tokens = np.pad(
|
116 |
+
semantic_tokens,[[0,0], [1, 0]], constant_values=c.SPKR_1)
|
117 |
+
|
118 |
+
if "name_2" in metadata:
|
119 |
+
wav, _ = sf.read(dataname.split(".")[0] + "_1.wav")
|
120 |
+
else:
|
121 |
+
wav, _ = sf.read(dataname)
|
122 |
+
audio = normalize(wav) * 0.95
|
123 |
+
speaker_embedding = self.spkr_embedding(
|
124 |
+
{"waveform": torch.FloatTensor(audio).unsqueeze(0),
|
125 |
+
"sample_rate": self.hp.sample_rate,}
|
126 |
+
).reshape(1, -1)
|
127 |
+
speaker_embedding = np.repeat(
|
128 |
+
speaker_embedding, semantic_tokens.shape[1], axis=0)
|
129 |
+
|
130 |
+
# Speaker 2
|
131 |
+
if "text_2" in metadata:
|
132 |
+
_name = _name.split(".wav")[0] + "_2.wav"
|
133 |
+
acoustic_tokens_2 = self.load_quantization(_name)
|
134 |
+
acoustic_tokens_2 = np.pad(
|
135 |
+
acoustic_tokens_2, [[1, 0],[0,0]], constant_values=c.SPKR_2)
|
136 |
+
|
137 |
+
npy_path = self.datadir.parent / 'audios-speech-tokenizer/semantic' / (os.path.splitext(_name)[0] + ".npy") # noqa
|
138 |
+
semantic_tokens_2 = np.load(npy_path)[None]
|
139 |
+
semantic_tokens_2 = np.pad(
|
140 |
+
semantic_tokens_2,[[0,0], [1, 0]], constant_values=c.SPKR_2)
|
141 |
+
|
142 |
+
wav, _ = sf.read(dataname.split(".wav")[0] + "_2.wav")
|
143 |
+
audio = normalize(wav) * 0.95
|
144 |
+
speaker_embedding_2 = self.spkr_embedding(
|
145 |
+
{"waveform": torch.FloatTensor(audio).unsqueeze(0),
|
146 |
+
"sample_rate": self.hp.sample_rate,}
|
147 |
+
).reshape(1, -1)
|
148 |
+
speaker_embedding_2 = np.repeat(
|
149 |
+
speaker_embedding_2, semantic_tokens_2.shape[1], axis=0)
|
150 |
+
|
151 |
+
# Merge both speakers
|
152 |
+
acoustic_tokens = np.concatenate(
|
153 |
+
(acoustic_tokens, acoustic_tokens_2), axis=0)
|
154 |
+
semantic_tokens = np.concatenate(
|
155 |
+
(semantic_tokens, semantic_tokens_2), axis=1)
|
156 |
+
speaker_embedding = np.concatenate(
|
157 |
+
(speaker_embedding, speaker_embedding_2), axis=0)
|
158 |
+
|
159 |
+
speaker_embedding = speaker_embedding[:self.hp.max_length, :]
|
160 |
+
acoustic_tokens = acoustic_tokens[:self.hp.max_length, :]
|
161 |
+
semantic_tokens = semantic_tokens[:, :self.hp.max_length]
|
162 |
+
|
163 |
+
# # HACK - we have no 8 lvls pfb30
|
164 |
+
# acoustic_tokens = np.concatenate((semantic_tokens.T, acoustic_tokens), axis=1)
|
165 |
+
# # END HACK
|
166 |
+
|
167 |
+
return speaker_embedding, acoustic_tokens, acoustic_tokens, dataname, semantic_tokens # noqa
|
modules/__init__.py
ADDED
File without changes
|
modules/conformer.py
ADDED
@@ -0,0 +1,671 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Conformer definition adjusted given the Lucidrain's repo.
|
2 |
+
https://github.com/lucidrains/soundstorm-pytorch/blob/main/soundstorm_pytorch/soundstorm.py # noqa
|
3 |
+
|
4 |
+
Copyright PolyAI Limited.
|
5 |
+
"""
|
6 |
+
from collections import namedtuple
|
7 |
+
from functools import wraps
|
8 |
+
from typing import Dict, Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from einops import rearrange, reduce
|
13 |
+
from einops.layers.torch import EinMix, Rearrange
|
14 |
+
from torch import einsum, nn
|
15 |
+
|
16 |
+
|
17 |
+
# rotary embedding
|
18 |
+
class RotaryEmbedding(nn.Module):
|
19 |
+
def __init__(self, dim, theta = 10000):
|
20 |
+
super().__init__()
|
21 |
+
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
|
22 |
+
self.register_buffer("inv_freq", inv_freq, persistent = False)
|
23 |
+
|
24 |
+
@property
|
25 |
+
def device(self):
|
26 |
+
return next(self.buffers()).device
|
27 |
+
|
28 |
+
def forward(self, seq_len):
|
29 |
+
t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)
|
30 |
+
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
|
31 |
+
freqs = torch.cat((freqs, freqs), dim = -1)
|
32 |
+
return freqs
|
33 |
+
|
34 |
+
def rotate_half(x):
|
35 |
+
x1, x2 = x.chunk(2, dim=-1)
|
36 |
+
return torch.cat((-x2, x1), dim=-1)
|
37 |
+
|
38 |
+
def apply_rotary_pos_emb(pos, t):
|
39 |
+
return (t * pos.cos()) + (rotate_half(t) * pos.sin())
|
40 |
+
|
41 |
+
|
42 |
+
# constants
|
43 |
+
EfficientAttentionConfig = namedtuple(
|
44 |
+
'EfficientAttentionConfig',
|
45 |
+
['enable_flash', 'enable_math', 'enable_mem_efficient']
|
46 |
+
)
|
47 |
+
|
48 |
+
# helpers
|
49 |
+
def exists(val):
|
50 |
+
return val is not None
|
51 |
+
|
52 |
+
def default(val, d):
|
53 |
+
return val if exists(val) else d
|
54 |
+
|
55 |
+
def divisible_by(numer, denom):
|
56 |
+
return (numer % denom) == 0
|
57 |
+
|
58 |
+
def calc_same_padding(kernel_size):
|
59 |
+
pad = kernel_size // 2
|
60 |
+
return (pad, pad - (kernel_size + 1) % 2)
|
61 |
+
|
62 |
+
def eval_decorator(fn):
|
63 |
+
@wraps(fn)
|
64 |
+
def inner(model, *args, **kwargs):
|
65 |
+
was_training = model.training
|
66 |
+
model.eval()
|
67 |
+
out = fn(model, *args, **kwargs)
|
68 |
+
model.train(was_training)
|
69 |
+
return out
|
70 |
+
return inner
|
71 |
+
|
72 |
+
|
73 |
+
def once(fn):
|
74 |
+
called = False
|
75 |
+
@wraps(fn)
|
76 |
+
def inner(x):
|
77 |
+
nonlocal called
|
78 |
+
if called:
|
79 |
+
return
|
80 |
+
called = True
|
81 |
+
return fn(x)
|
82 |
+
return inner
|
83 |
+
|
84 |
+
print_once = once(print)
|
85 |
+
|
86 |
+
|
87 |
+
# t5 relative positional bias
|
88 |
+
class T5RelativePositionBias(nn.Module):
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
scale = 1.,
|
92 |
+
num_buckets = 32,
|
93 |
+
max_distance = 128,
|
94 |
+
heads = 8
|
95 |
+
):
|
96 |
+
super().__init__()
|
97 |
+
self.scale = scale
|
98 |
+
self.num_buckets = num_buckets
|
99 |
+
self.max_distance = max_distance
|
100 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
|
101 |
+
|
102 |
+
@staticmethod
|
103 |
+
def _relative_position_bucket(
|
104 |
+
relative_position,
|
105 |
+
num_buckets = 32,
|
106 |
+
max_distance = 128
|
107 |
+
):
|
108 |
+
ret = 0
|
109 |
+
n = -relative_position
|
110 |
+
|
111 |
+
num_buckets //= 2
|
112 |
+
ret += (n < 0).long() * num_buckets
|
113 |
+
n = torch.abs(n)
|
114 |
+
|
115 |
+
max_exact = num_buckets // 2
|
116 |
+
is_small = n < max_exact
|
117 |
+
|
118 |
+
val_if_large = max_exact + (
|
119 |
+
torch.log(n.float() / max_exact) / math.log(
|
120 |
+
max_distance / max_exact) * (num_buckets - max_exact)
|
121 |
+
).long()
|
122 |
+
|
123 |
+
val_if_large = torch.min(
|
124 |
+
val_if_large,
|
125 |
+
torch.full_like(val_if_large, num_buckets - 1)
|
126 |
+
)
|
127 |
+
|
128 |
+
ret += torch.where(is_small, n, val_if_large)
|
129 |
+
return ret
|
130 |
+
|
131 |
+
@property
|
132 |
+
def device(self):
|
133 |
+
return next(self.parameters()).device
|
134 |
+
|
135 |
+
def forward(self, n):
|
136 |
+
pos = torch.arange(n, device = self.device).long()
|
137 |
+
rel_pos = rearrange(pos, 'j -> 1 j') - rearrange(pos, 'i -> i 1')
|
138 |
+
|
139 |
+
rp_bucket = self._relative_position_bucket(
|
140 |
+
rel_pos, num_buckets = self.num_buckets,
|
141 |
+
max_distance = self.max_distance)
|
142 |
+
values = self.relative_attention_bias(rp_bucket)
|
143 |
+
|
144 |
+
bias = rearrange(values, 'i j h -> h i j')
|
145 |
+
return bias * self.scale
|
146 |
+
|
147 |
+
|
148 |
+
# main class
|
149 |
+
class Attend(nn.Module):
|
150 |
+
def __init__(
|
151 |
+
self,
|
152 |
+
causal = False,
|
153 |
+
dropout = 0.,
|
154 |
+
flash = False
|
155 |
+
):
|
156 |
+
super().__init__()
|
157 |
+
self.dropout = dropout
|
158 |
+
self.attn_dropout = nn.Dropout(dropout)
|
159 |
+
|
160 |
+
self.causal = causal
|
161 |
+
self.flash = flash
|
162 |
+
|
163 |
+
# determine efficient attention configs for cuda and cpu
|
164 |
+
self.cpu_config = EfficientAttentionConfig(True, True, True)
|
165 |
+
self.cuda_config = None
|
166 |
+
|
167 |
+
if not torch.cuda.is_available() or not flash:
|
168 |
+
return
|
169 |
+
|
170 |
+
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
|
171 |
+
|
172 |
+
if device_properties.major == 8 and device_properties.minor == 0:
|
173 |
+
print_once('A100 GPU detected, using flash attention if input tensor is on cuda') # noqa
|
174 |
+
self.cuda_config = EfficientAttentionConfig(True, True, True)
|
175 |
+
else:
|
176 |
+
print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda') # noqa
|
177 |
+
self.cuda_config = EfficientAttentionConfig(False, True, True)
|
178 |
+
|
179 |
+
def get_mask(self, i, j, device):
|
180 |
+
return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1) # noqa
|
181 |
+
|
182 |
+
def flash_attn(self, q, k, v, mask = None, attn_bias = None):
|
183 |
+
_, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device # noqa
|
184 |
+
|
185 |
+
# single headed key / values
|
186 |
+
|
187 |
+
if k.ndim == 3:
|
188 |
+
k = rearrange(k, 'b n d -> b 1 n d')
|
189 |
+
|
190 |
+
if v.ndim == 3:
|
191 |
+
v = rearrange(v, 'b n d -> b 1 n d')
|
192 |
+
|
193 |
+
# Check if mask exists and expand to compatible shape
|
194 |
+
# The mask is B L, so it would have to be expanded to B H N L
|
195 |
+
if exists(mask) and mask.ndim != 4:
|
196 |
+
mask = rearrange(mask, 'b j -> b 1 1 j')
|
197 |
+
mask = mask.expand(-1, heads, q_len, -1)
|
198 |
+
|
199 |
+
# Check if there is a compatible device for flash attention
|
200 |
+
config = self.cuda_config if is_cuda else self.cpu_config
|
201 |
+
causal = self.causal
|
202 |
+
|
203 |
+
# handle attention bias
|
204 |
+
if exists(attn_bias):
|
205 |
+
mask_value = -torch.finfo(q.dtype).max // 2
|
206 |
+
causal_mask = self.get_mask(q_len, k_len, device)
|
207 |
+
attn_bias = attn_bias.masked_fill(causal_mask, mask_value)
|
208 |
+
|
209 |
+
if exists(mask):
|
210 |
+
attn_bias = attn_bias.masked_fill(~mask, mask_value)
|
211 |
+
|
212 |
+
mask = attn_bias
|
213 |
+
causal = False
|
214 |
+
|
215 |
+
# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
|
216 |
+
with torch.backends.cuda.sdp_kernel(**config._asdict()):
|
217 |
+
out = F.scaled_dot_product_attention(
|
218 |
+
q, k, v,
|
219 |
+
attn_mask = mask,
|
220 |
+
dropout_p = self.dropout if self.training else 0.,
|
221 |
+
is_causal = causal
|
222 |
+
)
|
223 |
+
|
224 |
+
return out
|
225 |
+
|
226 |
+
def forward(self, q, k, v, mask = None, attn_bias = None):
|
227 |
+
"""
|
228 |
+
einstein notation
|
229 |
+
b - batch
|
230 |
+
h - heads
|
231 |
+
n, i, j - sequence length (base sequence length, source, target)
|
232 |
+
d - feature dimension
|
233 |
+
"""
|
234 |
+
|
235 |
+
q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
|
236 |
+
|
237 |
+
scale = q.shape[-1] ** -0.5
|
238 |
+
|
239 |
+
kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
|
240 |
+
|
241 |
+
if self.flash:
|
242 |
+
assert not exists(attn_bias)
|
243 |
+
return self.flash_attn(q, k, v, mask = mask)
|
244 |
+
|
245 |
+
# similarity
|
246 |
+
|
247 |
+
sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
|
248 |
+
|
249 |
+
# attention bias
|
250 |
+
|
251 |
+
if exists(attn_bias):
|
252 |
+
sim = sim + attn_bias
|
253 |
+
|
254 |
+
# causal mask
|
255 |
+
if self.causal:
|
256 |
+
causal_mask = self.get_mask(q_len, k_len, device)
|
257 |
+
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
|
258 |
+
|
259 |
+
# key padding mask
|
260 |
+
if exists(mask):
|
261 |
+
if mask.ndim != 4:
|
262 |
+
mask = rearrange(mask, 'b j -> b 1 1 j')
|
263 |
+
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
|
264 |
+
|
265 |
+
# attention
|
266 |
+
attn = sim.softmax(dim=-1)
|
267 |
+
attn = self.attn_dropout(attn)
|
268 |
+
|
269 |
+
# aggregate values
|
270 |
+
out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
|
271 |
+
|
272 |
+
return out
|
273 |
+
|
274 |
+
|
275 |
+
class Swish(nn.Module):
|
276 |
+
def forward(self, x):
|
277 |
+
return x * x.sigmoid()
|
278 |
+
|
279 |
+
|
280 |
+
class GLU(nn.Module):
|
281 |
+
def __init__(self, dim):
|
282 |
+
super().__init__()
|
283 |
+
self.dim = dim
|
284 |
+
|
285 |
+
def forward(self, x):
|
286 |
+
out, gate = x.chunk(2, dim=self.dim)
|
287 |
+
return out * gate.sigmoid()
|
288 |
+
|
289 |
+
|
290 |
+
class DepthWiseConv1d(nn.Module):
|
291 |
+
def __init__(self, chan_in, chan_out, kernel_size, padding):
|
292 |
+
super().__init__()
|
293 |
+
self.padding = padding
|
294 |
+
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)
|
295 |
+
|
296 |
+
def forward(self, x):
|
297 |
+
x = F.pad(x, self.padding)
|
298 |
+
return self.conv(x)
|
299 |
+
|
300 |
+
|
301 |
+
class Scale(nn.Module):
|
302 |
+
def __init__(self, scale, fn):
|
303 |
+
super().__init__()
|
304 |
+
self.fn = fn
|
305 |
+
self.scale = scale
|
306 |
+
|
307 |
+
def forward(self, x, **kwargs):
|
308 |
+
return self.fn(x, **kwargs) * self.scale
|
309 |
+
|
310 |
+
|
311 |
+
class ChanLayerNorm(nn.Module):
|
312 |
+
def __init__(self, dim):
|
313 |
+
super().__init__()
|
314 |
+
self.gamma = nn.Parameter(torch.ones(1, dim, 1))
|
315 |
+
|
316 |
+
def forward(self, x):
|
317 |
+
eps = 1e-6 if x.dtype == torch.float32 else 1e-4
|
318 |
+
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
319 |
+
mean = torch.mean(x, dim = 1, keepdim = True)
|
320 |
+
return (x - mean) * var.clamp(min = eps).rsqrt() * self.gamma
|
321 |
+
|
322 |
+
|
323 |
+
class PreNorm(nn.Module):
|
324 |
+
def __init__(self, dim, fn):
|
325 |
+
super().__init__()
|
326 |
+
self.fn = fn
|
327 |
+
self.norm = nn.LayerNorm(dim)
|
328 |
+
|
329 |
+
def forward(self, x, **kwargs):
|
330 |
+
x = self.norm(x)
|
331 |
+
return self.fn(x, **kwargs)
|
332 |
+
|
333 |
+
|
334 |
+
class Attention(nn.Module):
|
335 |
+
def __init__(
|
336 |
+
self,
|
337 |
+
dim,
|
338 |
+
heads = 8,
|
339 |
+
dim_head = 64,
|
340 |
+
dropout = 0.,
|
341 |
+
flash = True
|
342 |
+
):
|
343 |
+
super().__init__()
|
344 |
+
inner_dim = dim_head * heads
|
345 |
+
self.heads= heads
|
346 |
+
self.scale = dim_head ** -0.5
|
347 |
+
|
348 |
+
self.attend = Attend(
|
349 |
+
flash = flash,
|
350 |
+
dropout = dropout
|
351 |
+
)
|
352 |
+
|
353 |
+
self.dropout = nn.Dropout(dropout)
|
354 |
+
|
355 |
+
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
356 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
357 |
+
self.to_out = nn.Linear(inner_dim, dim)
|
358 |
+
|
359 |
+
def forward(
|
360 |
+
self,
|
361 |
+
x,
|
362 |
+
context = None,
|
363 |
+
mask = None,
|
364 |
+
rotary_emb = None,
|
365 |
+
attn_bias = None
|
366 |
+
):
|
367 |
+
n, device, h, has_context = x.shape[-2], x.device, self.heads, exists(context)
|
368 |
+
context = default(context, x)
|
369 |
+
|
370 |
+
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
|
371 |
+
q, k, v = map(
|
372 |
+
lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
|
373 |
+
|
374 |
+
if exists(rotary_emb):
|
375 |
+
q = apply_rotary_pos_emb(rotary_emb, q)
|
376 |
+
k = apply_rotary_pos_emb(rotary_emb, k)
|
377 |
+
|
378 |
+
out = self.attend(q, k, v, mask = mask, attn_bias = attn_bias)
|
379 |
+
|
380 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
381 |
+
return self.to_out(out)
|
382 |
+
|
383 |
+
|
384 |
+
class FeedForward(nn.Module):
|
385 |
+
def __init__(
|
386 |
+
self,
|
387 |
+
dim,
|
388 |
+
mult = 4,
|
389 |
+
dropout = 0.
|
390 |
+
):
|
391 |
+
super().__init__()
|
392 |
+
self.net = nn.Sequential(
|
393 |
+
nn.Linear(dim, dim * mult),
|
394 |
+
Swish(),
|
395 |
+
nn.Dropout(dropout),
|
396 |
+
nn.Linear(dim * mult, dim),
|
397 |
+
nn.Dropout(dropout)
|
398 |
+
)
|
399 |
+
|
400 |
+
def forward(self, x):
|
401 |
+
return self.net(x)
|
402 |
+
|
403 |
+
|
404 |
+
class ConformerConvModule(nn.Module):
|
405 |
+
def __init__(
|
406 |
+
self,
|
407 |
+
dim,
|
408 |
+
causal = False,
|
409 |
+
expansion_factor = 2,
|
410 |
+
kernel_size = 31,
|
411 |
+
dropout = 0.
|
412 |
+
):
|
413 |
+
super().__init__()
|
414 |
+
|
415 |
+
inner_dim = dim * expansion_factor
|
416 |
+
padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)
|
417 |
+
|
418 |
+
self.net = nn.Sequential(
|
419 |
+
nn.LayerNorm(dim),
|
420 |
+
Rearrange('b n c -> b c n'),
|
421 |
+
nn.Conv1d(dim, inner_dim * 2, 1),
|
422 |
+
GLU(dim=1),
|
423 |
+
DepthWiseConv1d(
|
424 |
+
inner_dim, inner_dim, kernel_size = kernel_size,
|
425 |
+
padding = padding
|
426 |
+
),
|
427 |
+
Swish(),
|
428 |
+
ChanLayerNorm(inner_dim),
|
429 |
+
nn.Conv1d(inner_dim, dim, 1),
|
430 |
+
Rearrange('b c n -> b n c'),
|
431 |
+
nn.Dropout(dropout)
|
432 |
+
)
|
433 |
+
|
434 |
+
def forward(self, x):
|
435 |
+
return self.net(x)
|
436 |
+
|
437 |
+
|
438 |
+
# Conformer Block
|
439 |
+
class ConformerBlock(nn.Module):
|
440 |
+
def __init__(
|
441 |
+
self,
|
442 |
+
*,
|
443 |
+
dim,
|
444 |
+
dim_head = 64,
|
445 |
+
heads = 8,
|
446 |
+
ff_mult = 4,
|
447 |
+
conv_expansion_factor = 2,
|
448 |
+
conv_kernel_size = 31,
|
449 |
+
attn_dropout = 0.,
|
450 |
+
attn_flash = True,
|
451 |
+
ff_dropout = 0.,
|
452 |
+
conv_dropout = 0.,
|
453 |
+
conv_causal = False
|
454 |
+
):
|
455 |
+
super().__init__()
|
456 |
+
self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
|
457 |
+
self.attn = Attention(
|
458 |
+
dim = dim, dim_head = dim_head, heads = heads,
|
459 |
+
dropout = attn_dropout, flash = attn_flash
|
460 |
+
)
|
461 |
+
self.conv = ConformerConvModule(
|
462 |
+
dim = dim, causal = conv_causal,
|
463 |
+
expansion_factor = conv_expansion_factor,
|
464 |
+
kernel_size = conv_kernel_size, dropout = conv_dropout
|
465 |
+
)
|
466 |
+
self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
|
467 |
+
|
468 |
+
self.attn = PreNorm(dim, self.attn)
|
469 |
+
self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
|
470 |
+
self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))
|
471 |
+
|
472 |
+
self.post_norm = nn.LayerNorm(dim)
|
473 |
+
|
474 |
+
def forward(
|
475 |
+
self,
|
476 |
+
x,
|
477 |
+
mask = None,
|
478 |
+
rotary_emb = None,
|
479 |
+
attn_bias = None
|
480 |
+
):
|
481 |
+
x = self.ff1(x) + x
|
482 |
+
x = self.attn(x, mask = mask, rotary_emb = rotary_emb, attn_bias = attn_bias) + x # noqa
|
483 |
+
x = self.conv(x) + x
|
484 |
+
x = self.ff2(x) + x
|
485 |
+
x = self.post_norm(x)
|
486 |
+
return x
|
487 |
+
|
488 |
+
|
489 |
+
# Conformer
|
490 |
+
class Conformer(nn.Module):
|
491 |
+
def __init__(
|
492 |
+
self,
|
493 |
+
dim,
|
494 |
+
*,
|
495 |
+
num_layers,
|
496 |
+
dim_head = 64,
|
497 |
+
heads = 8,
|
498 |
+
ff_mult = 4,
|
499 |
+
conv_expansion_factor = 2,
|
500 |
+
conv_kernel_size = 31,
|
501 |
+
attn_dropout = 0.,
|
502 |
+
ff_dropout = 0.,
|
503 |
+
conv_dropout = 0.,
|
504 |
+
conv_causal = False,
|
505 |
+
attn_flash = True,
|
506 |
+
t5_rel_pos_bias = False
|
507 |
+
):
|
508 |
+
super().__init__()
|
509 |
+
|
510 |
+
assert not (t5_rel_pos_bias and attn_flash), 'flash attention is not compatible with learned bias' # noqa
|
511 |
+
|
512 |
+
self.dim = dim
|
513 |
+
self.layers = nn.ModuleList([])
|
514 |
+
|
515 |
+
self.rotary_emb = RotaryEmbedding(
|
516 |
+
dim_head) if not t5_rel_pos_bias else None
|
517 |
+
self.rel_pos_bias = T5RelativePositionBias(
|
518 |
+
dim_head ** 0.5, heads = heads) if t5_rel_pos_bias else None
|
519 |
+
|
520 |
+
for _ in range(num_layers):
|
521 |
+
self.layers.append(ConformerBlock(
|
522 |
+
dim = dim,
|
523 |
+
dim_head = dim_head,
|
524 |
+
heads = heads,
|
525 |
+
ff_mult = ff_mult,
|
526 |
+
conv_expansion_factor = conv_expansion_factor,
|
527 |
+
conv_kernel_size = conv_kernel_size,
|
528 |
+
attn_dropout = attn_dropout,
|
529 |
+
ff_dropout = ff_dropout,
|
530 |
+
conv_dropout = conv_dropout,
|
531 |
+
conv_causal = conv_causal,
|
532 |
+
attn_flash = attn_flash
|
533 |
+
))
|
534 |
+
|
535 |
+
def forward(self, x, mask = None):
|
536 |
+
seq_len = x.shape[-2]
|
537 |
+
|
538 |
+
rotary_emb = self.rotary_emb(seq_len) if exists(self.rotary_emb) else None # noqa
|
539 |
+
attn_bias = self.rel_pos_bias(seq_len) if exists(self.rel_pos_bias) else None #noqa
|
540 |
+
|
541 |
+
for block in self.layers:
|
542 |
+
x = block(
|
543 |
+
x,
|
544 |
+
mask = mask,
|
545 |
+
rotary_emb = rotary_emb,
|
546 |
+
attn_bias = attn_bias
|
547 |
+
)
|
548 |
+
return x
|
549 |
+
|
550 |
+
|
551 |
+
# conformer with sum reduction across quantized tokens at the beginning,
|
552 |
+
# along with heads
|
553 |
+
class ConformerWrapper(nn.Module):
|
554 |
+
def __init__(
|
555 |
+
self,
|
556 |
+
*,
|
557 |
+
codebook_size,
|
558 |
+
num_quantizers,
|
559 |
+
conformer: Union[Conformer, Dict[str, any]],
|
560 |
+
grouped_quantizers = 1
|
561 |
+
):
|
562 |
+
super().__init__()
|
563 |
+
self.conformer = conformer
|
564 |
+
|
565 |
+
if isinstance(conformer, dict):
|
566 |
+
self.conformer = Conformer(**self.conformer)
|
567 |
+
|
568 |
+
dim = self.conformer.dim
|
569 |
+
|
570 |
+
self.embedding_proj = nn.Sequential(
|
571 |
+
nn.Linear(dim * grouped_quantizers, dim),
|
572 |
+
nn.LayerNorm(dim)
|
573 |
+
) if grouped_quantizers > 1 else nn.Identity()
|
574 |
+
|
575 |
+
num_codes_with_mask = codebook_size + 1
|
576 |
+
num_effective_quantizers = num_quantizers * grouped_quantizers
|
577 |
+
|
578 |
+
self.code_embeds = nn.Embedding(
|
579 |
+
num_codes_with_mask * num_effective_quantizers, dim)
|
580 |
+
|
581 |
+
self.register_buffer(
|
582 |
+
'quantizer_offsets',
|
583 |
+
torch.arange(num_effective_quantizers) * num_codes_with_mask,
|
584 |
+
persistent = False
|
585 |
+
)
|
586 |
+
self.register_buffer(
|
587 |
+
'mask_tokens', self.quantizer_offsets + num_codes_with_mask,
|
588 |
+
persistent = False
|
589 |
+
)
|
590 |
+
|
591 |
+
self.dim = dim
|
592 |
+
self.codebook_size = codebook_size
|
593 |
+
|
594 |
+
self.num_codes_with_mask = num_codes_with_mask
|
595 |
+
self.num_quantizers = num_quantizers
|
596 |
+
self.grouped_quantizers = grouped_quantizers
|
597 |
+
|
598 |
+
self.heads = nn.Sequential(
|
599 |
+
nn.Linear(dim, dim * num_effective_quantizers),
|
600 |
+
Rearrange('b n (h d) -> b (n h) d', h = num_effective_quantizers)
|
601 |
+
)
|
602 |
+
|
603 |
+
# each quantizer codebook would require its own logits weight
|
604 |
+
# and bias matrices
|
605 |
+
# the amazing einops makes this easy with 'EinMix'
|
606 |
+
self.to_logits = nn.Sequential(
|
607 |
+
nn.LayerNorm(dim),
|
608 |
+
Rearrange('b (n gq) d -> b n gq d', gq = num_effective_quantizers),
|
609 |
+
EinMix(
|
610 |
+
'b n gq d -> b n gq l',
|
611 |
+
weight_shape = 'gq d l',
|
612 |
+
bias_shape = 'gq l',
|
613 |
+
gq = num_effective_quantizers,
|
614 |
+
l = codebook_size,
|
615 |
+
d = dim
|
616 |
+
),
|
617 |
+
Rearrange('b ... d -> b (...) d')
|
618 |
+
)
|
619 |
+
|
620 |
+
def forward(
|
621 |
+
self,
|
622 |
+
x,
|
623 |
+
*,
|
624 |
+
mask = None,
|
625 |
+
cond = None,
|
626 |
+
sum_embeds = None,
|
627 |
+
return_embeddings = False,
|
628 |
+
return_logits_and_embeddings = False
|
629 |
+
):
|
630 |
+
"""
|
631 |
+
einops notation:
|
632 |
+
b - batch
|
633 |
+
n - sequence
|
634 |
+
g - groups
|
635 |
+
q - quantizers
|
636 |
+
d - feature dimension
|
637 |
+
"""
|
638 |
+
|
639 |
+
n, q, g = x.shape[-1], self.num_quantizers, self.grouped_quantizers
|
640 |
+
assert divisible_by(n, g * q), 'sequence must be divisible by number of quantizers' # noqa
|
641 |
+
|
642 |
+
x = rearrange(x, 'b (n gq) -> b n gq', gq = g * q)
|
643 |
+
x = x + self.quantizer_offsets
|
644 |
+
|
645 |
+
x = self.code_embeds(x)
|
646 |
+
|
647 |
+
x = reduce(x, 'b n (g q) d -> b n (g d)', 'sum', g = g)
|
648 |
+
|
649 |
+
x = self.embedding_proj(x)
|
650 |
+
|
651 |
+
if exists(sum_embeds):
|
652 |
+
x = x + sum_embeds
|
653 |
+
|
654 |
+
if exists(cond):
|
655 |
+
if cond.ndim == 2:
|
656 |
+
cond = rearrange(cond, 'b d -> b 1 d')
|
657 |
+
|
658 |
+
x = x + cond
|
659 |
+
|
660 |
+
x = self.conformer(x, mask = mask)
|
661 |
+
embeds = self.heads(x)
|
662 |
+
|
663 |
+
if return_embeddings or not exists(self.to_logits):
|
664 |
+
return embeds
|
665 |
+
|
666 |
+
logits = self.to_logits(embeds)
|
667 |
+
|
668 |
+
if return_logits_and_embeddings:
|
669 |
+
return logits, embeds
|
670 |
+
|
671 |
+
return logits
|
modules/masking_logic.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Masking and sampling logic adapted from MaskGIT original paper:
|
2 |
+
https://github.com/google-research/maskgit
|
3 |
+
|
4 |
+
Copyright PolyAI Limited.
|
5 |
+
"""
|
6 |
+
from dataclasses import dataclass
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
|
13 |
+
@dataclass
|
14 |
+
class State:
|
15 |
+
"""Holds decoding state data."""
|
16 |
+
# The position of the decoding loop in the length dimension.
|
17 |
+
cur_index: None
|
18 |
+
# The active sequence log probabilities and finished sequence scores.
|
19 |
+
cur_seqs: None
|
20 |
+
final_seqs: None
|
21 |
+
|
22 |
+
|
23 |
+
def state_init(init_indices, num_iter, start_iter=0):
|
24 |
+
"""Initializes the decoding state data structure."""
|
25 |
+
cur_index_0 = start_iter
|
26 |
+
cur_seqs_0 = init_indices
|
27 |
+
final_seqs_0 = torch.unsqueeze(init_indices, 1)
|
28 |
+
final_seqs_0 = torch.tile(final_seqs_0, (1, num_iter, 1))
|
29 |
+
return State(
|
30 |
+
cur_index=cur_index_0, cur_seqs=cur_seqs_0, final_seqs=final_seqs_0)
|
31 |
+
|
32 |
+
|
33 |
+
def schedule(ratio, method="cosine"):
|
34 |
+
if method == "uniform":
|
35 |
+
mask_ratio = 1. - ratio
|
36 |
+
elif "pow" in method:
|
37 |
+
exponent = float(method.replace("pow", ""))
|
38 |
+
mask_ratio = 1. - ratio**exponent
|
39 |
+
elif method == "cosine":
|
40 |
+
mask_ratio = np.cos(ratio * (np.pi/2))
|
41 |
+
|
42 |
+
mask_ratio = np.clip(mask_ratio, 1e-6, 1.)
|
43 |
+
return mask_ratio
|
44 |
+
|
45 |
+
|
46 |
+
def mask_by_random_topk(mask_len, probs, temperature=1.0):
|
47 |
+
noise = gumbel_noise_like(probs)
|
48 |
+
confidence = torch.log(probs) + temperature * noise
|
49 |
+
sorted_confidence, _ = torch.sort(confidence, dim=-1)
|
50 |
+
# Obtains cut off threshold given the mask lengths.
|
51 |
+
cut_off = torch.take_along_dim(sorted_confidence, mask_len.long(), dim=-1)
|
52 |
+
# Masks tokens with lower confidence.
|
53 |
+
masking = (confidence < cut_off)
|
54 |
+
return masking
|
55 |
+
|
56 |
+
|
57 |
+
def gumbel_noise_like(t):
|
58 |
+
noise = torch.zeros_like(t).uniform_(1e-20, 1)
|
59 |
+
return -torch.log(-torch.log(noise))
|
60 |
+
|
61 |
+
|
62 |
+
def sample_from_logits(
|
63 |
+
logits,
|
64 |
+
sample: bool = True,
|
65 |
+
temperature: float = 1.0,
|
66 |
+
top_k: int = None,
|
67 |
+
top_p: float = None,
|
68 |
+
return_probs: bool = False
|
69 |
+
):
|
70 |
+
shp = logits.shape[:-1]
|
71 |
+
|
72 |
+
# Apply top_k sampling
|
73 |
+
if top_k is not None:
|
74 |
+
v, _ = logits.topk(top_k)
|
75 |
+
logits[logits < v[..., [-1]]] = -float("inf")
|
76 |
+
|
77 |
+
# Apply top_p (nucleus) sampling
|
78 |
+
if top_p is not None and top_p < 1.0:
|
79 |
+
v, sorted_indices = logits.sort(descending=True)
|
80 |
+
cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1)
|
81 |
+
|
82 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
83 |
+
# Right shift indices_to_remove to keep 1st token over threshold
|
84 |
+
sorted_indices_to_remove = F.pad(
|
85 |
+
sorted_indices_to_remove, (1, 0), value=False)[..., :-1]
|
86 |
+
|
87 |
+
# Compute indices_to_remove in unsorted array
|
88 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
89 |
+
-1, sorted_indices, sorted_indices_to_remove
|
90 |
+
)
|
91 |
+
|
92 |
+
logits[indices_to_remove] = -float("inf")
|
93 |
+
|
94 |
+
# Perform multinomial sampling after normalizing logits
|
95 |
+
probs = (
|
96 |
+
F.softmax(logits / temperature, dim=-1)
|
97 |
+
if temperature > 0
|
98 |
+
else logits.softmax(dim=-1)
|
99 |
+
)
|
100 |
+
token = (
|
101 |
+
probs.view(-1, probs.size(-1)).multinomial(1).squeeze(1).view(*shp)
|
102 |
+
if sample
|
103 |
+
else logits.argmax(-1)
|
104 |
+
)
|
105 |
+
|
106 |
+
if return_probs:
|
107 |
+
token_probs = probs.take_along_dim(
|
108 |
+
token.unsqueeze(-1), dim=-1).squeeze(-1)
|
109 |
+
return token, token_probs
|
110 |
+
else:
|
111 |
+
return token
|
modules/s2a_model.py
ADDED
@@ -0,0 +1,563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""A2S model definition.
|
2 |
+
|
3 |
+
Copyright PolyAI Limited.
|
4 |
+
"""
|
5 |
+
from typing import Union
|
6 |
+
|
7 |
+
import pytorch_lightning as pl
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torch.optim as optim
|
12 |
+
from einops import rearrange
|
13 |
+
|
14 |
+
import constants as c
|
15 |
+
from modules import masking_logic
|
16 |
+
from modules.conformer import Conformer
|
17 |
+
from modules.masking_logic import (State, mask_by_random_topk,
|
18 |
+
sample_from_logits, state_init)
|
19 |
+
from utils import load_checkpoint
|
20 |
+
|
21 |
+
|
22 |
+
class Pheme(pl.LightningModule):
|
23 |
+
def __init__(self, hp):
|
24 |
+
super().__init__()
|
25 |
+
self.hp = hp
|
26 |
+
self.model = TTSConformer(hp)
|
27 |
+
self.cross_entropy = nn.CrossEntropyLoss(
|
28 |
+
label_smoothing=self.hp.label_smoothing,
|
29 |
+
ignore_index=self.hp.n_codes
|
30 |
+
)
|
31 |
+
if self.hp.pretrained_path:
|
32 |
+
self.load()
|
33 |
+
else:
|
34 |
+
self.apply(self.init_weights)
|
35 |
+
|
36 |
+
if self.hp.only_inference:
|
37 |
+
self.model.eval()
|
38 |
+
|
39 |
+
self.save_hyperparameters()
|
40 |
+
|
41 |
+
def load(self):
|
42 |
+
state_dict = load_checkpoint(self.hp.pretrained_path)
|
43 |
+
print(f"Parameters loaded from {self.hp.pretrained_path}")
|
44 |
+
self.load_state_dict(state_dict, strict=True)
|
45 |
+
|
46 |
+
def init_weights(self, module):
|
47 |
+
if isinstance(module, nn.Linear):
|
48 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
49 |
+
if module.bias is not None:
|
50 |
+
module.bias.data.zero_()
|
51 |
+
if isinstance(module, nn.Embedding):
|
52 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
53 |
+
module._fill_padding_idx_with_zero()
|
54 |
+
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
|
55 |
+
module.bias.data.zero_()
|
56 |
+
module.weight.data.fill_(1.0)
|
57 |
+
elif isinstance(module, nn.Conv1d):
|
58 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
59 |
+
if module.bias is not None:
|
60 |
+
module.bias.data.zero_()
|
61 |
+
|
62 |
+
def configure_optimizers(self):
|
63 |
+
optimizer_adam = optim.AdamW(
|
64 |
+
self.parameters(), lr=self.hp.lr,
|
65 |
+
betas=(self.hp.adam_beta1, self.hp.adam_beta2))
|
66 |
+
|
67 |
+
# Learning rate scheduler
|
68 |
+
num_training_steps = self.hp.training_step
|
69 |
+
num_warmup_steps = self.hp.warmup_step
|
70 |
+
num_flat_steps = int(self.hp.optim_flat_percent * num_training_steps)
|
71 |
+
|
72 |
+
def lambda_lr(current_step: int):
|
73 |
+
if current_step < num_warmup_steps:
|
74 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
75 |
+
elif current_step < (num_warmup_steps + num_flat_steps):
|
76 |
+
return 1.0
|
77 |
+
return max(
|
78 |
+
0.0,
|
79 |
+
float(num_training_steps - current_step)
|
80 |
+
/ float(
|
81 |
+
max(1, num_training_steps - (num_warmup_steps + num_flat_steps)) # noqa
|
82 |
+
),
|
83 |
+
)
|
84 |
+
|
85 |
+
scheduler_adam = {
|
86 |
+
"scheduler": optim.lr_scheduler.LambdaLR(
|
87 |
+
optimizer_adam, lambda_lr),
|
88 |
+
"interval": "step",
|
89 |
+
}
|
90 |
+
return [optimizer_adam], [scheduler_adam]
|
91 |
+
|
92 |
+
def top_k_accuracy(self, y_true, y_pred_probabilities, k):
|
93 |
+
_, sorted_indices = torch.sort(y_pred_probabilities, descending=True)
|
94 |
+
|
95 |
+
# Get the top-k predictions
|
96 |
+
top_k_indices = sorted_indices[:, :k]
|
97 |
+
expanded_y_true = y_true.unsqueeze(1).expand_as(top_k_indices)
|
98 |
+
|
99 |
+
# Check if true labels exist in top-k predictions
|
100 |
+
hits = torch.sum(torch.eq(top_k_indices, expanded_y_true))
|
101 |
+
accuracy = hits.item() / (len(y_true) + 1e-7)
|
102 |
+
|
103 |
+
return accuracy
|
104 |
+
|
105 |
+
def training_step(self, batch, batch_idx):
|
106 |
+
# Sample training level
|
107 |
+
rvq_level = torch.randint(
|
108 |
+
0, min(self.hp.first_n_lvls, self.hp.n_cluster_groups),(1,)).item()
|
109 |
+
|
110 |
+
target, chosen_tokens, _, _ = self.model(
|
111 |
+
batch["tts_quantize_input"], rvq_level, batch["semantic_tokens"],
|
112 |
+
batch["quantization_lengths"],
|
113 |
+
speaker_emb=batch["speaker"],
|
114 |
+
min_seq_length=batch["quantization_lengths"].min().item())
|
115 |
+
|
116 |
+
# Mask targets and labels
|
117 |
+
mask = chosen_tokens
|
118 |
+
target = target[mask]
|
119 |
+
|
120 |
+
labels = batch["tts_quantize_input"][:, :, rvq_level]
|
121 |
+
labels = labels[mask]
|
122 |
+
|
123 |
+
loss = self.cross_entropy(target, labels)
|
124 |
+
acc = (target.argmax(-1) == labels).float().mean()
|
125 |
+
self.log("train/loss", loss, on_step=True, prog_bar=True)
|
126 |
+
self.log("train/acc", acc, on_step=True, prog_bar=True)
|
127 |
+
self.log(
|
128 |
+
f"train/acc_lvl_{rvq_level}", acc, on_step=True, prog_bar=False)
|
129 |
+
|
130 |
+
return loss
|
131 |
+
|
132 |
+
def validation_step(self, batch, batch_idx, dataloader_idx=0):
|
133 |
+
speaker_emb = batch["speaker"]
|
134 |
+
acoustic_tokens = batch["tts_quantize_input"]
|
135 |
+
semantic_tokens = batch["semantic_tokens"]
|
136 |
+
|
137 |
+
if self.hp.only_inference:
|
138 |
+
self.inference(
|
139 |
+
acoustic_tokens, semantic_tokens, self.hp.first_n_lvls)
|
140 |
+
else:
|
141 |
+
rvq_level = torch.randint(
|
142 |
+
0, min(self.hp.first_n_lvls, self.hp.n_cluster_groups),(1,)
|
143 |
+
).item()
|
144 |
+
|
145 |
+
# FIXME: edge case
|
146 |
+
if len(semantic_tokens.shape) == 3:
|
147 |
+
semantic_tokens = rearrange(semantic_tokens, "B 1 T -> B T")
|
148 |
+
|
149 |
+
target, chosen_tokens, _, _ = self.model(
|
150 |
+
acoustic_tokens, rvq_level, semantic_tokens,
|
151 |
+
torch.tensor([acoustic_tokens.shape[1]]).to(self.device),
|
152 |
+
speaker_emb=speaker_emb,
|
153 |
+
min_seq_length=acoustic_tokens.shape[1]
|
154 |
+
)
|
155 |
+
|
156 |
+
target = target[chosen_tokens]
|
157 |
+
labels = acoustic_tokens[:, :, rvq_level][chosen_tokens]
|
158 |
+
loss = self.cross_entropy(target, labels)
|
159 |
+
|
160 |
+
acc = (target.argmax(-1) == labels).float().mean()
|
161 |
+
acc_5 = self.top_k_accuracy(labels, target, 5)
|
162 |
+
|
163 |
+
self.log(
|
164 |
+
f"val/dataset_{dataloader_idx}/loss",
|
165 |
+
loss,
|
166 |
+
on_epoch=True,
|
167 |
+
logger=True,
|
168 |
+
add_dataloader_idx=False,
|
169 |
+
)
|
170 |
+
self.log(
|
171 |
+
f"val/dataset_{dataloader_idx}/acc_lvl",
|
172 |
+
acc,
|
173 |
+
on_epoch=True,
|
174 |
+
logger=True,
|
175 |
+
add_dataloader_idx=False,
|
176 |
+
)
|
177 |
+
self.log(
|
178 |
+
f"val/dataset_{dataloader_idx}/acc_lvl_{rvq_level}",
|
179 |
+
acc,
|
180 |
+
on_epoch=True,
|
181 |
+
logger=True,
|
182 |
+
add_dataloader_idx=False,
|
183 |
+
)
|
184 |
+
self.log(
|
185 |
+
f"val/dataset_{dataloader_idx}/acc_top_5",
|
186 |
+
acc_5,
|
187 |
+
on_epoch=True,
|
188 |
+
logger=True,
|
189 |
+
add_dataloader_idx=False,
|
190 |
+
)
|
191 |
+
self.log(
|
192 |
+
f"val/dataset_{dataloader_idx}/acc_top_5_lvl_{rvq_level}",
|
193 |
+
acc_5,
|
194 |
+
on_epoch=True,
|
195 |
+
logger=True,
|
196 |
+
add_dataloader_idx=False,
|
197 |
+
)
|
198 |
+
|
199 |
+
def compute_stats(self, logits, labels, mask_ratio=0, rvq_level=0):
|
200 |
+
acc = (logits.argmax(-1) == labels).float().mean()
|
201 |
+
acc_5 = self.top_k_accuracy(labels, logits, 5)
|
202 |
+
acc_10 = self.top_k_accuracy(labels, logits, 10)
|
203 |
+
|
204 |
+
idx = torch.randperm(logits.shape[0])
|
205 |
+
logits_shuffled = logits[idx]
|
206 |
+
random = self.top_k_accuracy(labels, logits_shuffled, 10)
|
207 |
+
print(f"Mask ratio: {mask_ratio}, Level {rvq_level}: acc {acc},"
|
208 |
+
f"acc 5 {acc_5}, acc 10 {acc_10}, quasi random {random}")
|
209 |
+
|
210 |
+
|
211 |
+
class TTSConformer(pl.LightningModule):
|
212 |
+
def __init__(self, hp):
|
213 |
+
super().__init__()
|
214 |
+
self.hp = hp
|
215 |
+
self.padding_id = self.hp.n_codes
|
216 |
+
|
217 |
+
additional_codes = [c.PAD, c.SPKR_1, c.SPKR_2]
|
218 |
+
|
219 |
+
self.embedding = nn.ModuleList(
|
220 |
+
[
|
221 |
+
nn.Embedding(
|
222 |
+
self.hp.n_codes + len(additional_codes),
|
223 |
+
self.hp.hidden_size,
|
224 |
+
padding_idx=self.padding_id)
|
225 |
+
for _ in range(self.hp.n_cluster_groups)
|
226 |
+
]
|
227 |
+
)
|
228 |
+
|
229 |
+
# Additional modules
|
230 |
+
self.semantic_embedding = nn.Embedding(
|
231 |
+
self.hp.n_semantic_codes + len(additional_codes),
|
232 |
+
self.hp.hidden_size,
|
233 |
+
padding_idx=self.padding_id)
|
234 |
+
|
235 |
+
if self.hp.use_spkr_emb:
|
236 |
+
self.spkr_linear = nn.Linear(c.SPKR_EMB_SIZE, self.hp.hidden_size)
|
237 |
+
|
238 |
+
self.conformer = Conformer(
|
239 |
+
dim=self.hp.hidden_size,
|
240 |
+
num_layers=self.hp.enc_nlayers,
|
241 |
+
heads=self.hp.nheads,
|
242 |
+
dim_head=64,
|
243 |
+
ff_mult=4, # 512*4=2048
|
244 |
+
conv_expansion_factor=2,
|
245 |
+
conv_kernel_size=self.hp.depthwise_conv_kernel_size,
|
246 |
+
attn_dropout=self.hp.dropout,
|
247 |
+
ff_dropout=self.hp.dropout,
|
248 |
+
conv_dropout=self.hp.dropout,
|
249 |
+
attn_flash=True,
|
250 |
+
t5_rel_pos_bias=False
|
251 |
+
)
|
252 |
+
|
253 |
+
self.heads = nn.ModuleList(
|
254 |
+
[
|
255 |
+
nn.Linear(
|
256 |
+
self.hp.hidden_size,
|
257 |
+
self.hp.n_codes + len(additional_codes)
|
258 |
+
)
|
259 |
+
for _ in range(self.hp.n_cluster_groups)
|
260 |
+
]
|
261 |
+
)
|
262 |
+
|
263 |
+
def build_mask_from_lengths(self, length, max_len=None):
|
264 |
+
max_len = max_len or length.max().item()
|
265 |
+
mask = torch.arange(
|
266 |
+
max_len, device=length.device)[None, :] >= length[:, None]
|
267 |
+
return mask.bool()
|
268 |
+
|
269 |
+
@torch.no_grad()
|
270 |
+
def create_mask(
|
271 |
+
self, B, T, lengths, mask_ratio=None, start_t=None,
|
272 |
+
min_seq_length=None
|
273 |
+
):
|
274 |
+
# 1. Define the random length of condition tokens given the shortest
|
275 |
+
# audio in the batch
|
276 |
+
if start_t is None:
|
277 |
+
start_t = torch.randint(1, min_seq_length - 1, (1,)).item()
|
278 |
+
|
279 |
+
# 2. Mask other tokens - sample different masking levels per
|
280 |
+
if mask_ratio is None:
|
281 |
+
ratio = torch.rand(1).item()
|
282 |
+
mask_ratio = masking_logic.schedule(ratio)
|
283 |
+
|
284 |
+
# Create a random tensor with values between 0 and 1
|
285 |
+
random_tensor = torch.rand(
|
286 |
+
(B, T - start_t), dtype=torch.float).to(self.device)
|
287 |
+
# Create a mask where values less than p are set to True
|
288 |
+
initial_mask = random_tensor < mask_ratio
|
289 |
+
length_mask = self.build_mask_from_lengths(
|
290 |
+
lengths - start_t, T - start_t)
|
291 |
+
# we can't pick up tokens past token lengths
|
292 |
+
initial_mask = torch.logical_and(initial_mask, ~length_mask)
|
293 |
+
|
294 |
+
# Constrain ratio to always include some samples
|
295 |
+
# If all are False let's pick up at least one:
|
296 |
+
if torch.sum(initial_mask) == 0:
|
297 |
+
choose_steps = torch.randint(low=0, high=(T - start_t), size=(B,))
|
298 |
+
initial_mask[torch.arange(B), choose_steps] = torch.tensor(
|
299 |
+
True, device=self.device)
|
300 |
+
|
301 |
+
# 3. Add condition tokens containing information
|
302 |
+
acoustic_token_mask = torch.cat(
|
303 |
+
(torch.full((B, start_t), False, device=self.device), initial_mask), # noqa
|
304 |
+
1
|
305 |
+
)
|
306 |
+
|
307 |
+
return acoustic_token_mask, start_t, mask_ratio
|
308 |
+
|
309 |
+
def process_input(
|
310 |
+
self, data, lengths, rvq_level, min_seq_length=None,
|
311 |
+
mask_ratio=None, start_t=None, acoustic_token_mask=None
|
312 |
+
):
|
313 |
+
"""
|
314 |
+
data: (B, T, code_level, D)
|
315 |
+
rvq_level: int
|
316 |
+
"""
|
317 |
+
B = data.size(0)
|
318 |
+
T = data.size(1)
|
319 |
+
level_data = data[:, :, rvq_level, :] # [B, T, C, D] -> [B, T, D]
|
320 |
+
|
321 |
+
# Choose acoustic tokens to mask
|
322 |
+
if acoustic_token_mask is None:
|
323 |
+
acoustic_token_mask, start_t, mask_ratio = self.create_mask(
|
324 |
+
B, T, lengths, mask_ratio=mask_ratio, start_t=start_t,
|
325 |
+
min_seq_length=min_seq_length)
|
326 |
+
# Remove code information from chosen tokens
|
327 |
+
level_data[acoustic_token_mask, :] = 0
|
328 |
+
|
329 |
+
# Embed only lower rvq_level
|
330 |
+
lower_code_data = data[:, :, :rvq_level, :].sum(dim=2)
|
331 |
+
|
332 |
+
# Combine with chosen tokens at rvq_level.
|
333 |
+
# Note: all tokens at rvq_level+1: will be discarded.
|
334 |
+
summed_data = torch.add(lower_code_data, level_data)
|
335 |
+
|
336 |
+
return summed_data, acoustic_token_mask, mask_ratio, start_t
|
337 |
+
|
338 |
+
def forward(
|
339 |
+
self, x, code_level, semantic_tokens, lengths,
|
340 |
+
speaker_emb=None, min_seq_length=10, mask_ratio=None, start_t=None,
|
341 |
+
acoustic_token_mask=None
|
342 |
+
):
|
343 |
+
# FIXME: parallelize this
|
344 |
+
batch = []
|
345 |
+
for lvl, embed in enumerate(self.embedding[:(code_level + 1)]):
|
346 |
+
batch.append(embed(x[:, :, lvl])) # [B T D]
|
347 |
+
|
348 |
+
x = torch.stack(batch, dim=2) # [B T C D]
|
349 |
+
x, acoustic_token_mask, mask_ratio, start_t = self.process_input(
|
350 |
+
x, lengths, code_level, min_seq_length=min_seq_length,
|
351 |
+
mask_ratio=mask_ratio, start_t=start_t,
|
352 |
+
acoustic_token_mask=acoustic_token_mask
|
353 |
+
)
|
354 |
+
|
355 |
+
# Add phoneme embeddings
|
356 |
+
# Cross attention for all tokens?
|
357 |
+
|
358 |
+
# Add semantic tokens
|
359 |
+
# HACK ME
|
360 |
+
semantic_emb = self.semantic_embedding(semantic_tokens)
|
361 |
+
x = torch.add(x, semantic_emb)
|
362 |
+
# FIXME pfb30
|
363 |
+
|
364 |
+
# Merge different modalities
|
365 |
+
if self.hp.use_spkr_emb:
|
366 |
+
spkr_emb = F.normalize(speaker_emb, dim=-1)
|
367 |
+
spkr_emb = self.spkr_linear(
|
368 |
+
F.dropout(spkr_emb, self.hp.speaker_embed_dropout)
|
369 |
+
)
|
370 |
+
x = torch.add(x, spkr_emb)
|
371 |
+
|
372 |
+
output_frames = self.conformer(x, None)
|
373 |
+
|
374 |
+
x = self.heads[code_level](output_frames)
|
375 |
+
|
376 |
+
return x, acoustic_token_mask, mask_ratio, start_t
|
377 |
+
|
378 |
+
@torch.no_grad()
|
379 |
+
def inference(
|
380 |
+
self, codes, semantic_tokens,
|
381 |
+
length: torch.LongTensor, rvq_levels=7,
|
382 |
+
mask_ratio=0.99, maskgit_inference=True,
|
383 |
+
start_t: Union[torch.LongTensor, None] = None,
|
384 |
+
speaker_emb=None, steps=16
|
385 |
+
):
|
386 |
+
# Use half of the recording for the conditioning
|
387 |
+
if start_t is None:
|
388 |
+
start_t = torch.tensor(int((codes.shape[1]) / 2)).long()
|
389 |
+
|
390 |
+
start_t = start_t.item()
|
391 |
+
|
392 |
+
for rvq_level in range(rvq_levels):
|
393 |
+
original_codes = torch.clone(codes)
|
394 |
+
if rvq_level == 0 and maskgit_inference:
|
395 |
+
codes = self.multi_step_inference(
|
396 |
+
original_codes, semantic_tokens, length,
|
397 |
+
start_t=start_t, vamp_filtering=False,
|
398 |
+
speaker_emb=speaker_emb, steps=16
|
399 |
+
)
|
400 |
+
else:
|
401 |
+
codes = self.one_step_inference(
|
402 |
+
original_codes, semantic_tokens, length,
|
403 |
+
code_level=rvq_level,
|
404 |
+
mask_ratio=mask_ratio, start_t=start_t,
|
405 |
+
speaker_emb=speaker_emb
|
406 |
+
)
|
407 |
+
|
408 |
+
codes = rearrange(codes, 'T C -> 1 T C')
|
409 |
+
|
410 |
+
# Remove any padding left
|
411 |
+
codes = rearrange(codes, '1 T C -> 1 C T')
|
412 |
+
codes = torch.where(codes >= self.hp.n_codes, 0, codes)
|
413 |
+
acoustic_tokens = codes
|
414 |
+
semantic_tokens = rearrange(semantic_tokens, 'b c -> b 1 c')
|
415 |
+
semantic_tokens = torch.where(
|
416 |
+
semantic_tokens >= self.hp.n_codes, 0, semantic_tokens)
|
417 |
+
codes = torch.cat([semantic_tokens, acoustic_tokens], dim=1)
|
418 |
+
|
419 |
+
return codes
|
420 |
+
|
421 |
+
@torch.no_grad()
|
422 |
+
def one_step_inference(
|
423 |
+
self, original_codes, semantic_tokens, lengths, code_level=0,
|
424 |
+
mask_ratio=0.99, start_t=0, inference_setup="argmax", speaker_emb=None
|
425 |
+
):
|
426 |
+
codes = torch.clone(original_codes)
|
427 |
+
logits, _, _, _ = self.forward(
|
428 |
+
codes, code_level, semantic_tokens, lengths,
|
429 |
+
mask_ratio=mask_ratio, start_t=start_t,
|
430 |
+
speaker_emb=speaker_emb, acoustic_token_mask=False)
|
431 |
+
|
432 |
+
if inference_setup == "argmax":
|
433 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
434 |
+
top_indeces = torch.argmax(probs, dim=-1)
|
435 |
+
|
436 |
+
if inference_setup == "sampling":
|
437 |
+
top_indeces = torch.distributions.Categorical(
|
438 |
+
logits=logits).sample()
|
439 |
+
|
440 |
+
codes = rearrange(codes, '1 T C -> T C')
|
441 |
+
codes[start_t:, code_level] = top_indeces[0, start_t:]
|
442 |
+
|
443 |
+
return codes
|
444 |
+
|
445 |
+
@torch.no_grad()
|
446 |
+
def multi_step_inference(
|
447 |
+
self, original_codes, semantic_tokens, lengths,
|
448 |
+
start_t: torch.LongTensor=None,
|
449 |
+
choice_temperature=1.0, start_iter=0,
|
450 |
+
steps=16, vamp_filtering=False, speaker_emb=None
|
451 |
+
):
|
452 |
+
codes = torch.clone(original_codes)
|
453 |
+
code_level = 0
|
454 |
+
_, seq_len, _ = original_codes.shape
|
455 |
+
mask_token_id = self.padding_id
|
456 |
+
|
457 |
+
# Get true codes for the prompt
|
458 |
+
prompt_mask = codes[:, :start_t, code_level]
|
459 |
+
|
460 |
+
# Fill up rest with masks
|
461 |
+
mask = torch.full(
|
462 |
+
(1, seq_len - start_t), mask_token_id, device=self.device)
|
463 |
+
inputs = torch.cat((prompt_mask, mask), 1)
|
464 |
+
|
465 |
+
num_mask_tokens_at_start = torch.sum(inputs == mask_token_id, axis=-1)
|
466 |
+
|
467 |
+
# Initializes state
|
468 |
+
state = state_init(inputs, steps, start_iter=start_iter)
|
469 |
+
|
470 |
+
def loop_cond_fn(state):
|
471 |
+
"""Beam search loop termination condition."""
|
472 |
+
not_at_end = (state.cur_index < steps)
|
473 |
+
return not_at_end
|
474 |
+
|
475 |
+
while loop_cond_fn(state):
|
476 |
+
"""Beam search loop state update function."""
|
477 |
+
step = state.cur_index
|
478 |
+
# Current input ids: [batch_size, seq_length].
|
479 |
+
cur_ids = state.cur_seqs
|
480 |
+
|
481 |
+
# Calls model on current seqs to get next-iteration seqs.
|
482 |
+
with torch.no_grad():
|
483 |
+
logits, _, _, _ = self.forward(
|
484 |
+
rearrange(inputs, 'B T -> B T 1'),
|
485 |
+
code_level,
|
486 |
+
semantic_tokens, lengths,
|
487 |
+
acoustic_token_mask=False,
|
488 |
+
speaker_emb=speaker_emb)
|
489 |
+
|
490 |
+
# Samples the ids using categorical sampling:
|
491 |
+
if vamp_filtering:
|
492 |
+
typical_mass = 0.2
|
493 |
+
typical_min_tokens = 1
|
494 |
+
top_p = None
|
495 |
+
sample_cutoff = 0.5
|
496 |
+
typical_filtering = False
|
497 |
+
sampled_ids, selected_probs = sample_from_logits(
|
498 |
+
logits, sample=((step / steps) <= sample_cutoff),
|
499 |
+
temperature=choice_temperature,
|
500 |
+
typical_filtering=typical_filtering,
|
501 |
+
typical_mass=typical_mass,
|
502 |
+
typical_min_tokens=typical_min_tokens,
|
503 |
+
top_k=None, top_p=top_p, return_probs=True,
|
504 |
+
)
|
505 |
+
else:
|
506 |
+
sampled_ids = torch.distributions.Categorical(
|
507 |
+
logits=logits).sample()
|
508 |
+
|
509 |
+
# Just updates the masked tokens.
|
510 |
+
unknown_map = (cur_ids == mask_token_id)
|
511 |
+
sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)
|
512 |
+
# Defines the mask ratio for the next round. The number to mask out
|
513 |
+
# is determined by mask_ratio * unknown_number_in_the_beginning.
|
514 |
+
ratio = 1. * (step + 1) / steps
|
515 |
+
mask_ratio = masking_logic.schedule(ratio)
|
516 |
+
|
517 |
+
# Updates final seqs with the current sampled_ids.
|
518 |
+
final_seqs = torch.clone(state.final_seqs)
|
519 |
+
final_seqs[:, step, :] = sampled_ids
|
520 |
+
# Computes the probabilities of each selected tokens.
|
521 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
522 |
+
# Extract the probabilities of sampled ids
|
523 |
+
selected_probs = torch.squeeze(
|
524 |
+
torch.take_along_dim(
|
525 |
+
probs, torch.unsqueeze(sampled_ids, -1) , -1),
|
526 |
+
-1
|
527 |
+
)
|
528 |
+
|
529 |
+
# Ignores the tokens given in the input
|
530 |
+
# by overwriting their confidence.
|
531 |
+
selected_probs = torch.where(
|
532 |
+
unknown_map, selected_probs, torch.inf)
|
533 |
+
# Gets mask lens for each sample in the
|
534 |
+
# batch according to the mask ratio.
|
535 |
+
num_to_mask = torch.unsqueeze(
|
536 |
+
torch.floor(num_mask_tokens_at_start * mask_ratio), 1)
|
537 |
+
|
538 |
+
# Keeps at least one of prediction in this
|
539 |
+
# round and also masks out at least
|
540 |
+
# one and for the next iteration
|
541 |
+
num_to_mask = torch.maximum(
|
542 |
+
torch.tensor(1),
|
543 |
+
torch.minimum(
|
544 |
+
torch.sum(unknown_map, dim=-1, keepdim=True) - 1,
|
545 |
+
num_to_mask)
|
546 |
+
)
|
547 |
+
# Adds noise for randomness
|
548 |
+
masking = mask_by_random_topk(
|
549 |
+
num_to_mask, selected_probs, choice_temperature * (1. - ratio))
|
550 |
+
# Masks tokens with lower confidence.
|
551 |
+
sampled_ids = torch.where(masking, mask_token_id, sampled_ids)
|
552 |
+
|
553 |
+
state = State(
|
554 |
+
cur_index=state.cur_index + 1,
|
555 |
+
cur_seqs=sampled_ids,
|
556 |
+
final_seqs=final_seqs
|
557 |
+
)
|
558 |
+
|
559 |
+
codes = torch.clone(original_codes)
|
560 |
+
codes = rearrange(codes, '1 T C -> T C')
|
561 |
+
codes[:, 0] = state.final_seqs[0][-1]
|
562 |
+
|
563 |
+
return codes
|
modules/speech_tokenizer.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Speech tokenizer class.
|
2 |
+
|
3 |
+
Copyright PolyAI Limited.
|
4 |
+
"""
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torchaudio
|
11 |
+
from speechtokenizer import SpeechTokenizer as ST
|
12 |
+
|
13 |
+
from modules.tokenizer import BaseTokenizer
|
14 |
+
|
15 |
+
|
16 |
+
class SpeechTokenizer(BaseTokenizer):
|
17 |
+
def __init__(self, config_path: str, ckpt_path: str):
|
18 |
+
self.device = torch.device(
|
19 |
+
"cuda" if torch.cuda.is_available() else "cpu")
|
20 |
+
self.model = ST.load_from_checkpoint(
|
21 |
+
config_path, ckpt_path).to(self.device)
|
22 |
+
self.model.eval()
|
23 |
+
|
24 |
+
def encode_file(
|
25 |
+
self, folder_path: str, destination_folder: str, filename: str):
|
26 |
+
dest_path = os.path.join(
|
27 |
+
destination_folder, "semantic",
|
28 |
+
os.path.splitext(filename)[0] + ".npy"
|
29 |
+
)
|
30 |
+
dest_path2 = os.path.join(
|
31 |
+
destination_folder, "acoustic",
|
32 |
+
os.path.splitext(filename)[0] + ".npy"
|
33 |
+
)
|
34 |
+
if os.path.exists(dest_path) and os.path.exists(dest_path2):
|
35 |
+
pass
|
36 |
+
else:
|
37 |
+
self._create_subfolders(destination_folder=destination_folder)
|
38 |
+
|
39 |
+
file_path = os.path.join(folder_path, filename)
|
40 |
+
wav_info = torchaudio.info(file_path)
|
41 |
+
wav_dur_sec = wav_info.num_frames / wav_info.sample_rate
|
42 |
+
if wav_dur_sec > 60:
|
43 |
+
logging.info(
|
44 |
+
f"Skipping {file_path} is too long: {wav_dur_sec:.3f} sec,"
|
45 |
+
"can cause CUDA OOM"
|
46 |
+
)
|
47 |
+
return
|
48 |
+
wav, sr = torchaudio.load(file_path)
|
49 |
+
if sr != self.model.sample_rate:
|
50 |
+
logging.warning(
|
51 |
+
"Wav sample rate %(wav_sr)s does not match the model"
|
52 |
+
"sampling rate %(model_sr)s. Resampling audio",
|
53 |
+
{"wav_sr": sr, "model_sr": self.model.sample_rate},
|
54 |
+
)
|
55 |
+
wav = torchaudio.functional.resample(
|
56 |
+
wav, sr, self.model.sample_rate)
|
57 |
+
wav = wav.unsqueeze(0)
|
58 |
+
wav = wav.to(self.device)
|
59 |
+
|
60 |
+
# Extract discrete codes from SpeechTokenizer
|
61 |
+
with torch.no_grad():
|
62 |
+
codes = self.model.encode(wav) # codes: (n_q, B, T)
|
63 |
+
|
64 |
+
semantic_tokens = codes[0, 0, :]
|
65 |
+
acoustic_tokens = codes[1:, 0, :]
|
66 |
+
|
67 |
+
# Save the encoding as .npy
|
68 |
+
dest_path = os.path.join(
|
69 |
+
destination_folder, "acoustic",
|
70 |
+
os.path.splitext(filename)[0] + ".npy"
|
71 |
+
)
|
72 |
+
np.save(dest_path, acoustic_tokens.cpu().numpy())
|
73 |
+
|
74 |
+
dest_path = os.path.join(
|
75 |
+
destination_folder, "semantic",
|
76 |
+
os.path.splitext(filename)[0] + ".npy"
|
77 |
+
)
|
78 |
+
np.save(dest_path, semantic_tokens.cpu().numpy())
|
79 |
+
|
80 |
+
@staticmethod
|
81 |
+
def _create_subfolders(destination_folder: str):
|
82 |
+
if not os.path.exists(destination_folder + "/acoustic"):
|
83 |
+
os.makedirs(destination_folder + "/acoustic")
|
84 |
+
|
85 |
+
if not os.path.exists(destination_folder + "/semantic"):
|
86 |
+
os.makedirs(destination_folder + "/semantic")
|
modules/t2s_model.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""T2S model definition.
|
2 |
+
|
3 |
+
Copyright PolyAI Limited.
|
4 |
+
"""
|
5 |
+
import os
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
from torch import nn
|
9 |
+
from transformers import EvalPrediction, T5Config, T5ForConditionalGeneration
|
10 |
+
|
11 |
+
from data.collation import get_text_semantic_token_collater
|
12 |
+
|
13 |
+
|
14 |
+
def compute_custom_metrics(eval_prediction: EvalPrediction):
|
15 |
+
# eval_prediction: tuple
|
16 |
+
# eval_prediction[0]: tensor of decoder outputs(logits) (n_batch, n_semantic, n_tokens) # noqa
|
17 |
+
# eval_prediction[1]: tensor of encoder outputs (n_batch, n_text/n_phone, n_hidden) # noqa
|
18 |
+
logits = eval_prediction.predictions[0]
|
19 |
+
labels = eval_prediction.label_ids
|
20 |
+
n_vocab = logits.shape[-1]
|
21 |
+
mask = labels == -100
|
22 |
+
top_1 = np.argmax(logits, axis=-1) == labels
|
23 |
+
top_1[mask] = False
|
24 |
+
top_5 = np.argsort(logits, axis=-1)[:, :, -5:]
|
25 |
+
top_5 = np.any(top_5 == np.expand_dims(labels, axis=-1), axis=-1)
|
26 |
+
top_5[mask] = False
|
27 |
+
|
28 |
+
top_10 = np.argsort(logits, axis=-1)[:, :, -10:]
|
29 |
+
top_10 = np.any(top_10 == np.expand_dims(labels, axis=-1), axis=-1)
|
30 |
+
top_10[mask] = False
|
31 |
+
|
32 |
+
top_1_accuracy = np.sum(top_1) / np.sum(~mask)
|
33 |
+
top_5_accuracy = np.sum(top_5) / np.sum(~mask)
|
34 |
+
top_10_accuracy = np.sum(top_10) / np.sum(~mask)
|
35 |
+
|
36 |
+
return {
|
37 |
+
"top_1_accuracy": top_1_accuracy,
|
38 |
+
"top_5_accuracy": top_5_accuracy,
|
39 |
+
"top_10_accuracy": top_10_accuracy,
|
40 |
+
}
|
41 |
+
|
42 |
+
|
43 |
+
class T2S(nn.Module):
|
44 |
+
def __init__(self, hp):
|
45 |
+
super().__init__()
|
46 |
+
self.text_tokens_file = "ckpt/unique_text_tokens.k2symbols"
|
47 |
+
self.collater = get_text_semantic_token_collater(self.text_tokens_file)
|
48 |
+
self.model_size = hp.model_size
|
49 |
+
self.vocab_size = len(self.collater.idx2token)
|
50 |
+
self.config = self._define_model_config(self.model_size)
|
51 |
+
|
52 |
+
print(f"{self.config = }")
|
53 |
+
self.t2s = T5ForConditionalGeneration(self.config)
|
54 |
+
|
55 |
+
def _define_model_config(self, model_size):
|
56 |
+
if model_size == "test":
|
57 |
+
# n_params = 16M
|
58 |
+
d_ff = 16
|
59 |
+
d_model = 8
|
60 |
+
d_kv = 32
|
61 |
+
num_heads = 1
|
62 |
+
num_decoder_layers = 1
|
63 |
+
num_layers = 1
|
64 |
+
elif model_size == "tiny":
|
65 |
+
# n_params = 16M
|
66 |
+
d_ff = 1024
|
67 |
+
d_model = 256
|
68 |
+
d_kv = 32
|
69 |
+
num_heads = 4
|
70 |
+
num_decoder_layers = 4
|
71 |
+
num_layers = 4
|
72 |
+
elif model_size == "t5small":
|
73 |
+
# n_params = 60M
|
74 |
+
d_ff = 2048
|
75 |
+
d_model = 512
|
76 |
+
d_kv = 64
|
77 |
+
num_heads = 8
|
78 |
+
num_decoder_layers = 6
|
79 |
+
num_layers = 6
|
80 |
+
elif model_size == "large":
|
81 |
+
# n_params = 100M
|
82 |
+
d_ff = 2048
|
83 |
+
d_model = 512
|
84 |
+
d_kv = 64
|
85 |
+
num_heads = 8
|
86 |
+
num_decoder_layers = 14
|
87 |
+
num_layers = 14
|
88 |
+
elif model_size == "Large":
|
89 |
+
# n_params = 114M
|
90 |
+
d_ff = 4096
|
91 |
+
d_model = 512
|
92 |
+
d_kv = 64
|
93 |
+
num_heads = 8
|
94 |
+
num_decoder_layers = 6
|
95 |
+
num_layers = 10
|
96 |
+
else:
|
97 |
+
raise ValueError(f"unknown {model_size}")
|
98 |
+
|
99 |
+
config = T5Config(
|
100 |
+
d_ff=d_ff,
|
101 |
+
d_model=d_model,
|
102 |
+
d_kv=d_kv,
|
103 |
+
num_heads=num_heads,
|
104 |
+
num_decoder_layers=num_decoder_layers,
|
105 |
+
num_layers=num_layers,
|
106 |
+
decoder_start_token_id=0,
|
107 |
+
eos_token_id=2,
|
108 |
+
vocab_size=self.vocab_size,
|
109 |
+
)
|
110 |
+
|
111 |
+
return config
|
modules/tokenizer.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Base tokenizer class.
|
2 |
+
|
3 |
+
Copyright PolyAI Limited.
|
4 |
+
"""
|
5 |
+
import os
|
6 |
+
from asyncio import as_completed
|
7 |
+
from concurrent.futures import ThreadPoolExecutor
|
8 |
+
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from utils import measure_duration
|
12 |
+
|
13 |
+
|
14 |
+
class BaseTokenizer:
|
15 |
+
@measure_duration
|
16 |
+
def encode_files_with_model_seq(
|
17 |
+
self, folder_path: str, destination_folder: str):
|
18 |
+
# Ensure destination folder exists
|
19 |
+
if not os.path.exists(destination_folder):
|
20 |
+
os.makedirs(destination_folder)
|
21 |
+
|
22 |
+
# Go through each file in the folder
|
23 |
+
filenames = os.listdir(folder_path)
|
24 |
+
# encoding files has no side effects
|
25 |
+
for filename in tqdm(filenames):
|
26 |
+
self.encode_file(
|
27 |
+
folder_path=folder_path,
|
28 |
+
destination_folder=destination_folder,
|
29 |
+
filename=filename,
|
30 |
+
)
|
31 |
+
|
32 |
+
def get_chunk(self, folder_path, start_percent=0, end_percent=100):
|
33 |
+
filenames = os.listdir(folder_path)
|
34 |
+
total_files = len(filenames)
|
35 |
+
|
36 |
+
start_idx = int(total_files * (start_percent / 100))
|
37 |
+
end_idx = int(total_files * (end_percent / 100))
|
38 |
+
|
39 |
+
return filenames[start_idx:end_idx]
|
40 |
+
|
41 |
+
@measure_duration
|
42 |
+
def encode_files_with_model_concurrent(
|
43 |
+
self, folder_path: str, destination_folder: str, start_percent: int,
|
44 |
+
end_percent: int,
|
45 |
+
):
|
46 |
+
# Ensure destination folder exists
|
47 |
+
if not os.path.exists(destination_folder):
|
48 |
+
os.makedirs(destination_folder)
|
49 |
+
|
50 |
+
# Go through each file in the folder
|
51 |
+
filenames = self.get_chunk(folder_path, start_percent, end_percent)
|
52 |
+
|
53 |
+
# encoding files has no side effects
|
54 |
+
with ThreadPoolExecutor(max_workers=40) as executor:
|
55 |
+
futures = [
|
56 |
+
executor.submit(
|
57 |
+
self.encode_file,
|
58 |
+
folder_path=folder_path,
|
59 |
+
destination_folder=destination_folder,
|
60 |
+
filename=filename,
|
61 |
+
)
|
62 |
+
for filename in filenames
|
63 |
+
]
|
64 |
+
# Wait for all tasks to complete
|
65 |
+
for future in as_completed(futures):
|
66 |
+
future.result()
|
67 |
+
|
68 |
+
# Explicitly shut down the thread pool
|
69 |
+
executor.shutdown()
|
70 |
+
|
71 |
+
def encode_file(
|
72 |
+
self, folder_path: str, destination_folder: str, filename: str):
|
73 |
+
raise NotImplementedError
|
modules/vocoder.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Vocoder wrapper.
|
2 |
+
|
3 |
+
Copyright PolyAI Limited.
|
4 |
+
"""
|
5 |
+
import enum
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import soundfile as sf
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from speechtokenizer import SpeechTokenizer
|
12 |
+
|
13 |
+
|
14 |
+
class VocoderType(enum.Enum):
|
15 |
+
SPEECHTOKENIZER = ("SPEECHTOKENIZER", 320)
|
16 |
+
|
17 |
+
def __init__(self, name, compression_ratio):
|
18 |
+
self._name_ = name
|
19 |
+
self.compression_ratio = compression_ratio
|
20 |
+
|
21 |
+
def get_vocoder(self, ckpt_path, config_path, **kwargs):
|
22 |
+
if self.name == "SPEECHTOKENIZER":
|
23 |
+
if ckpt_path:
|
24 |
+
vocoder = STWrapper(ckpt_path, config_path)
|
25 |
+
else:
|
26 |
+
vocoder = STWrapper()
|
27 |
+
else:
|
28 |
+
raise ValueError(f"Unknown vocoder type {self.name}")
|
29 |
+
return vocoder
|
30 |
+
|
31 |
+
|
32 |
+
class STWrapper(nn.Module):
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
ckpt_path: str = './ckpt/speechtokenizer/SpeechTokenizer.pt',
|
36 |
+
config_path = './ckpt/speechtokenizer/config.json',
|
37 |
+
):
|
38 |
+
super().__init__()
|
39 |
+
self.model = SpeechTokenizer.load_from_checkpoint(
|
40 |
+
config_path, ckpt_path)
|
41 |
+
|
42 |
+
def eval(self):
|
43 |
+
self.model.eval()
|
44 |
+
|
45 |
+
@torch.no_grad()
|
46 |
+
def decode(self, codes: torch.Tensor, verbose: bool = False):
|
47 |
+
original_device = codes.device
|
48 |
+
|
49 |
+
codes = codes.to(self.device)
|
50 |
+
audio_array = self.model.decode(codes)
|
51 |
+
|
52 |
+
return audio_array.to(original_device)
|
53 |
+
|
54 |
+
def decode_to_file(self, codes_path, out_path) -> None:
|
55 |
+
codes = np.load(codes_path)
|
56 |
+
codes = torch.from_numpy(codes)
|
57 |
+
wav = self.decode(codes).cpu().numpy()
|
58 |
+
sf.write(out_path, wav, samplerate=self.model.sample_rate)
|
59 |
+
|
60 |
+
@torch.no_grad()
|
61 |
+
def encode(self, wav, verbose=False, n_quantizers: int = None):
|
62 |
+
original_device = wav.device
|
63 |
+
wav = wav.to(self.device)
|
64 |
+
codes = self.model.encode(wav) # codes: (n_q, B, T)
|
65 |
+
return codes.to(original_device)
|
66 |
+
|
67 |
+
def encode_to_file(self, wav_path, out_path) -> None:
|
68 |
+
wav, _ = sf.read(wav_path, dtype='float32')
|
69 |
+
wav = torch.from_numpy(wav).unsqueeze(0).unsqueeze(0)
|
70 |
+
codes = self.encode(wav).cpu().numpy()
|
71 |
+
np.save(out_path, codes)
|
72 |
+
|
73 |
+
def remove_weight_norm(self):
|
74 |
+
pass
|
75 |
+
|
76 |
+
@property
|
77 |
+
def device(self):
|
78 |
+
return next(self.model.parameters()).device
|
79 |
+
|
transformer_infer.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Inference logic.
|
2 |
+
|
3 |
+
Copyright PolyAI Limited.
|
4 |
+
"""
|
5 |
+
import argparse
|
6 |
+
import json
|
7 |
+
import logging
|
8 |
+
import os
|
9 |
+
import time
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import soundfile as sf
|
14 |
+
import torch
|
15 |
+
from einops import rearrange
|
16 |
+
from librosa.util import normalize
|
17 |
+
from pyannote.audio import Inference
|
18 |
+
from transformers import GenerationConfig, T5ForConditionalGeneration
|
19 |
+
|
20 |
+
import constants as c
|
21 |
+
from data.collation import get_text_semantic_token_collater
|
22 |
+
from data.semantic_dataset import TextTokenizer
|
23 |
+
from modules.s2a_model import Pheme
|
24 |
+
from modules.vocoder import VocoderType
|
25 |
+
|
26 |
+
# How many times one token can be generated
|
27 |
+
MAX_TOKEN_COUNT = 100
|
28 |
+
|
29 |
+
logging.basicConfig(level=logging.DEBUG)
|
30 |
+
device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
|
31 |
+
|
32 |
+
|
33 |
+
def parse_arguments():
|
34 |
+
parser = argparse.ArgumentParser()
|
35 |
+
parser.add_argument(
|
36 |
+
"--text", type=str,
|
37 |
+
default="I gotta say, I would never expect that to happen!"
|
38 |
+
)
|
39 |
+
parser.add_argument(
|
40 |
+
"--manifest_path", type=str, default="demo/manifest.json")
|
41 |
+
parser.add_argument("--outputdir", type=str, default="demo/")
|
42 |
+
parser.add_argument("--featuredir", type=str, default="demo/")
|
43 |
+
parser.add_argument(
|
44 |
+
"--text_tokens_file", type=str,
|
45 |
+
default="ckpt/unique_text_tokens.k2symbols"
|
46 |
+
)
|
47 |
+
parser.add_argument("--t2s_path", type=str, default="ckpt/t2s/")
|
48 |
+
parser.add_argument(
|
49 |
+
"--a2s_path", type=str, default="ckpt/s2a/s2a.ckpt")
|
50 |
+
|
51 |
+
parser.add_argument("--target_sample_rate", type=int, default=16_000)
|
52 |
+
|
53 |
+
parser.add_argument("--temperature", type=float, default=0.7)
|
54 |
+
parser.add_argument("--top_k", type=int, default=210)
|
55 |
+
parser.add_argument("--voice", type=str, default="male_voice")
|
56 |
+
|
57 |
+
return parser.parse_args()
|
58 |
+
|
59 |
+
|
60 |
+
class PhemeClient():
|
61 |
+
def __init__(self, args):
|
62 |
+
self.args = args
|
63 |
+
self.outputdir = args.outputdir
|
64 |
+
self.target_sample_rate = args.target_sample_rate
|
65 |
+
self.featuredir = Path(args.featuredir).expanduser()
|
66 |
+
self.collater = get_text_semantic_token_collater(args.text_tokens_file)
|
67 |
+
self.phonemizer = TextTokenizer()
|
68 |
+
|
69 |
+
self.load_manifest(args.manifest_path)
|
70 |
+
|
71 |
+
# T2S model
|
72 |
+
self.t2s = T5ForConditionalGeneration.from_pretrained(args.t2s_path)
|
73 |
+
self.t2s = T5ForConditionalGeneration.
|
74 |
+
self.t2s.to(device)
|
75 |
+
self.t2s.eval()
|
76 |
+
|
77 |
+
# S2A model
|
78 |
+
self.s2a = Pheme.load_from_checkpoint(args.a2s_path)
|
79 |
+
self.s2a.to(device=device)
|
80 |
+
self.s2a.eval()
|
81 |
+
|
82 |
+
# Vocoder
|
83 |
+
vocoder = VocoderType["SPEECHTOKENIZER"].get_vocoder(None, None)
|
84 |
+
self.vocoder = vocoder.to(device)
|
85 |
+
self.vocoder.eval()
|
86 |
+
|
87 |
+
self.spkr_embedding = Inference(
|
88 |
+
"pyannote/embedding",
|
89 |
+
window="whole",
|
90 |
+
use_auth_token=os.environ["HUGGING_FACE_HUB_TOKEN"],
|
91 |
+
)
|
92 |
+
|
93 |
+
def load_manifest(self, input_path):
|
94 |
+
input_file = {}
|
95 |
+
with open(input_path, "rb") as f:
|
96 |
+
for line in f:
|
97 |
+
temp = json.loads(line)
|
98 |
+
input_file[temp["audio_filepath"].split(".wav")[0]] = temp
|
99 |
+
self.input_file = input_file
|
100 |
+
|
101 |
+
def lazy_decode(self, decoder_output, symbol_table):
|
102 |
+
semantic_tokens = map(lambda x: symbol_table[x], decoder_output)
|
103 |
+
semantic_tokens = [int(x) for x in semantic_tokens if x.isdigit()]
|
104 |
+
|
105 |
+
return np.array(semantic_tokens)
|
106 |
+
|
107 |
+
def infer_text(self, text, voice, sampling_config):
|
108 |
+
semantic_prompt = np.load(self.args.featuredir + "/audios-speech-tokenizer/semantic/" + f"{voice}.npy") # noqa
|
109 |
+
phones_seq = self.phonemizer(text)[0]
|
110 |
+
input_ids = self.collater([phones_seq])
|
111 |
+
input_ids = input_ids.type(torch.IntTensor).to(device)
|
112 |
+
|
113 |
+
labels = [str(lbl) for lbl in semantic_prompt]
|
114 |
+
labels = self.collater([labels])[:, :-1]
|
115 |
+
decoder_input_ids = labels.to(device).long()
|
116 |
+
logging.debug(f"decoder_input_ids: {decoder_input_ids}")
|
117 |
+
|
118 |
+
counts = 1E10
|
119 |
+
while (counts > MAX_TOKEN_COUNT):
|
120 |
+
output_ids = self.t2s.generate(
|
121 |
+
input_ids, decoder_input_ids=decoder_input_ids,
|
122 |
+
generation_config=sampling_config).sequences
|
123 |
+
|
124 |
+
# check repetitiveness
|
125 |
+
_, counts = torch.unique_consecutive(output_ids, return_counts=True)
|
126 |
+
counts = max(counts).item()
|
127 |
+
|
128 |
+
output_semantic = self.lazy_decode(
|
129 |
+
output_ids[0], self.collater.idx2token)
|
130 |
+
|
131 |
+
# remove the prompt
|
132 |
+
return output_semantic[len(semantic_prompt):].reshape(1, -1)
|
133 |
+
|
134 |
+
def _load_speaker_emb(self, element_id_prompt):
|
135 |
+
wav, _ = sf.read(self.featuredir / element_id_prompt)
|
136 |
+
audio = normalize(wav) * 0.95
|
137 |
+
speaker_emb = self.spkr_embedding(
|
138 |
+
{
|
139 |
+
"waveform": torch.FloatTensor(audio).unsqueeze(0),
|
140 |
+
"sample_rate": self.target_sample_rate
|
141 |
+
}
|
142 |
+
).reshape(1, -1)
|
143 |
+
|
144 |
+
return speaker_emb
|
145 |
+
|
146 |
+
def _load_prompt(self, prompt_file_path):
|
147 |
+
element_id_prompt = Path(prompt_file_path).stem
|
148 |
+
acoustic_path_prompt = self.featuredir / "audios-speech-tokenizer/acoustic" / f"{element_id_prompt}.npy" # noqa
|
149 |
+
semantic_path_prompt = self.featuredir / "audios-speech-tokenizer/semantic" / f"{element_id_prompt}.npy" # noqa
|
150 |
+
|
151 |
+
acoustic_prompt = np.load(acoustic_path_prompt).squeeze().T
|
152 |
+
semantic_prompt = np.load(semantic_path_prompt)[None]
|
153 |
+
|
154 |
+
return acoustic_prompt, semantic_prompt
|
155 |
+
|
156 |
+
def infer_acoustic(self, output_semantic, prompt_file_path):
|
157 |
+
semantic_tokens = output_semantic.reshape(1, -1)
|
158 |
+
acoustic_tokens = np.full(
|
159 |
+
[semantic_tokens.shape[1], 7], fill_value=c.PAD)
|
160 |
+
|
161 |
+
acoustic_prompt, semantic_prompt = self._load_prompt(prompt_file_path) # noqa
|
162 |
+
|
163 |
+
# Prepend prompt
|
164 |
+
acoustic_tokens = np.concatenate(
|
165 |
+
[acoustic_prompt, acoustic_tokens], axis=0)
|
166 |
+
semantic_tokens = np.concatenate([
|
167 |
+
semantic_prompt, semantic_tokens], axis=1)
|
168 |
+
|
169 |
+
# Add speaker
|
170 |
+
acoustic_tokens = np.pad(
|
171 |
+
acoustic_tokens, [[1, 0], [0, 0]], constant_values=c.SPKR_1)
|
172 |
+
semantic_tokens = np.pad(
|
173 |
+
semantic_tokens, [[0,0], [1, 0]], constant_values=c.SPKR_1)
|
174 |
+
|
175 |
+
speaker_emb = None
|
176 |
+
if self.s2a.hp.use_spkr_emb:
|
177 |
+
speaker_emb = self._load_speaker_emb(prompt_file_path)
|
178 |
+
speaker_emb = np.repeat(
|
179 |
+
speaker_emb, semantic_tokens.shape[1], axis=0)
|
180 |
+
speaker_emb = torch.from_numpy(speaker_emb).to(device)
|
181 |
+
else:
|
182 |
+
speaker_emb = None
|
183 |
+
|
184 |
+
acoustic_tokens = torch.from_numpy(
|
185 |
+
acoustic_tokens).unsqueeze(0).to(device).long()
|
186 |
+
semantic_tokens = torch.from_numpy(semantic_tokens).to(device).long()
|
187 |
+
start_t = torch.tensor(
|
188 |
+
[acoustic_prompt.shape[0]], dtype=torch.long, device=device)
|
189 |
+
length = torch.tensor([
|
190 |
+
semantic_tokens.shape[1]], dtype=torch.long, device=device)
|
191 |
+
|
192 |
+
codes = self.s2a.model.inference(
|
193 |
+
acoustic_tokens,
|
194 |
+
semantic_tokens,
|
195 |
+
start_t=start_t,
|
196 |
+
length=length,
|
197 |
+
maskgit_inference=True,
|
198 |
+
speaker_emb=speaker_emb
|
199 |
+
)
|
200 |
+
|
201 |
+
# Remove the prompt
|
202 |
+
synth_codes = codes[:, :, start_t:]
|
203 |
+
synth_codes = rearrange(synth_codes, "b c t -> c b t")
|
204 |
+
|
205 |
+
return synth_codes
|
206 |
+
|
207 |
+
def generate_audio(self, text, voice, sampling_config, prompt_file_path):
|
208 |
+
start_time = time.time()
|
209 |
+
output_semantic = self.infer_text(
|
210 |
+
text, voice, sampling_config
|
211 |
+
)
|
212 |
+
logging.debug(f"semantic_tokens: {time.time() - start_time}")
|
213 |
+
|
214 |
+
start_time = time.time()
|
215 |
+
codes = self.infer_acoustic(output_semantic, prompt_file_path)
|
216 |
+
logging.debug(f"acoustic_tokens: {time.time() - start_time}")
|
217 |
+
|
218 |
+
start_time = time.time()
|
219 |
+
audio_array = self.vocoder.decode(codes)
|
220 |
+
audio_array = rearrange(audio_array, "1 1 T -> T").cpu().numpy()
|
221 |
+
logging.debug(f"vocoder time: {time.time() - start_time}")
|
222 |
+
|
223 |
+
return audio_array
|
224 |
+
|
225 |
+
@torch.no_grad()
|
226 |
+
def infer(
|
227 |
+
self, text, voice="male_voice", temperature=0.7,
|
228 |
+
top_k=210, max_new_tokens=750,
|
229 |
+
):
|
230 |
+
sampling_config = GenerationConfig.from_pretrained(
|
231 |
+
self.args.t2s_path,
|
232 |
+
top_k=top_k,
|
233 |
+
num_beams=1,
|
234 |
+
do_sample=True,
|
235 |
+
temperature=temperature,
|
236 |
+
num_return_sequences=1,
|
237 |
+
max_new_tokens=max_new_tokens,
|
238 |
+
return_dict_in_generate=True,
|
239 |
+
output_scores=True
|
240 |
+
)
|
241 |
+
|
242 |
+
voice_data = self.input_file[voice]
|
243 |
+
prompt_file_path = voice_data["audio_prompt_filepath"]
|
244 |
+
text = voice_data["text"] + " " + text
|
245 |
+
|
246 |
+
audio_array = self.generate_audio(
|
247 |
+
text, voice, sampling_config, prompt_file_path)
|
248 |
+
|
249 |
+
return audio_array
|
250 |
+
|
251 |
+
|
252 |
+
if __name__ == "__main__":
|
253 |
+
args = parse_arguments()
|
254 |
+
args.outputdir = Path(args.outputdir).expanduser()
|
255 |
+
args.outputdir.mkdir(parents=True, exist_ok=True)
|
256 |
+
args.manifest_path = Path(args.manifest_path).expanduser()
|
257 |
+
|
258 |
+
client = PhemeClient(args)
|
259 |
+
audio_array = client.infer(args.text, voice=args.voice)
|
260 |
+
sf.write(os.path.join(
|
261 |
+
args.outputdir, f"{args.voice}.wav"), audio_array,
|
262 |
+
args.target_sample_rate
|
263 |
+
)
|
utils/__init__.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Copyright PolyAI Limited."""
|
2 |
+
import logging
|
3 |
+
import pdb
|
4 |
+
import sys
|
5 |
+
import traceback
|
6 |
+
from functools import wraps
|
7 |
+
from time import time
|
8 |
+
from typing import List
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from .symbol_table import SymbolTable
|
13 |
+
|
14 |
+
|
15 |
+
def load_checkpoint(ckpt_path: str) -> dict:
|
16 |
+
"""
|
17 |
+
Loads checkpoint, while matching phone embedding size.
|
18 |
+
"""
|
19 |
+
state_dict: dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
|
20 |
+
new_state_dict = dict()
|
21 |
+
for p_name in state_dict.keys():
|
22 |
+
if p_name.startswith("vocoder"):
|
23 |
+
continue
|
24 |
+
|
25 |
+
new_state_dict[p_name] = state_dict[p_name]
|
26 |
+
|
27 |
+
return new_state_dict
|
28 |
+
|
29 |
+
|
30 |
+
def breakpoint_on_error(fn):
|
31 |
+
"""Creates a breakpoint on error
|
32 |
+
|
33 |
+
Use as a wrapper
|
34 |
+
|
35 |
+
Args:
|
36 |
+
fn: the function
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
inner function
|
40 |
+
"""
|
41 |
+
|
42 |
+
def inner(*args, **kwargs):
|
43 |
+
try:
|
44 |
+
return fn(*args, **kwargs)
|
45 |
+
except Exception:
|
46 |
+
"""Standard python way of creating a breakpoint on error"""
|
47 |
+
extype, value, tb = sys.exc_info()
|
48 |
+
print(f"extype={extype},\nvalue={value}")
|
49 |
+
traceback.print_exc()
|
50 |
+
pdb.post_mortem(tb)
|
51 |
+
|
52 |
+
return inner
|
53 |
+
|
54 |
+
|
55 |
+
def measure_duration(f):
|
56 |
+
@wraps(f)
|
57 |
+
def wrap(*args, **kw):
|
58 |
+
ts = time()
|
59 |
+
result = f(*args, **kw)
|
60 |
+
te = time()
|
61 |
+
logging.debug("func:%r took: %2.4f sec" % (f.__name__, te - ts))
|
62 |
+
return result
|
63 |
+
|
64 |
+
return wrap
|
65 |
+
|
66 |
+
|
67 |
+
def split_metapath(in_paths: List[str]):
|
68 |
+
other_paths = []
|
69 |
+
|
70 |
+
for itm_path in in_paths:
|
71 |
+
other_paths.append(itm_path)
|
72 |
+
|
73 |
+
return other_paths
|
utils/get_tokens_speech_tokenizer.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Get tokens using the SpeechTokenizer.
|
2 |
+
|
3 |
+
Apply SpeechTokenizer to extract acoustic and semantic tokens.
|
4 |
+
The tokens will be extracted to
|
5 |
+
encoding_output/acoustic and encoding_output/semantic.
|
6 |
+
|
7 |
+
python utils/get_tokens_speech_tokenizer.py \
|
8 |
+
--config_path ckpt/speechtokenizer/config.json \
|
9 |
+
--ckpt_path ckpt/speechtokenizer/SpeechTokenizer.pt \
|
10 |
+
--encoding_input datasets/example/audios \
|
11 |
+
--encoding_output datasets/example/audios-speech-tokenizer
|
12 |
+
|
13 |
+
Copyright PolyAI Limited.
|
14 |
+
"""
|
15 |
+
import argparse
|
16 |
+
import pathlib
|
17 |
+
|
18 |
+
from modules.speech_tokenizer import SpeechTokenizer
|
19 |
+
|
20 |
+
MQTTS_ROOT_PATH = str(pathlib.Path(__file__).parent.resolve())
|
21 |
+
|
22 |
+
if __name__ == "__main__":
|
23 |
+
parser = argparse.ArgumentParser()
|
24 |
+
parser.add_argument(
|
25 |
+
"--config_path",
|
26 |
+
type=str,
|
27 |
+
help="Path to the SpeechTokenizer config",
|
28 |
+
default=MQTTS_ROOT_PATH + "/ckpt/speechtokenizer/config.json",
|
29 |
+
)
|
30 |
+
parser.add_argument(
|
31 |
+
"--ckpt_path",
|
32 |
+
type=str,
|
33 |
+
help="Path to the SpeechTokenizer checkpoint",
|
34 |
+
default=MQTTS_ROOT_PATH + "/ckpt/speechtokenizer/SpeechTokenizer.pt",
|
35 |
+
)
|
36 |
+
parser.add_argument(
|
37 |
+
"--encoding_input",
|
38 |
+
type=str,
|
39 |
+
help="Path to the input folder for encoding",
|
40 |
+
default=MQTTS_ROOT_PATH + "/datasets/giga-training-data/audios",
|
41 |
+
)
|
42 |
+
parser.add_argument(
|
43 |
+
"--encoding_output",
|
44 |
+
type=str,
|
45 |
+
help="Path where to save the encoded tokens",
|
46 |
+
default="/tmp/encoding_output",
|
47 |
+
)
|
48 |
+
parser.add_argument(
|
49 |
+
"--start_percent",
|
50 |
+
type=int,
|
51 |
+
default=0,
|
52 |
+
)
|
53 |
+
parser.add_argument(
|
54 |
+
"--end_percent",
|
55 |
+
type=int,
|
56 |
+
default=100,
|
57 |
+
)
|
58 |
+
|
59 |
+
args = parser.parse_args()
|
60 |
+
print("Parsed args")
|
61 |
+
print(args)
|
62 |
+
|
63 |
+
tokenizer = SpeechTokenizer(
|
64 |
+
config_path=args.config_path,
|
65 |
+
ckpt_path=args.ckpt_path,
|
66 |
+
)
|
67 |
+
tokenizer.encode_files_with_model_concurrent(
|
68 |
+
folder_path=args.encoding_input, destination_folder=args.encoding_output,
|
69 |
+
start_percent=args.start_percent, end_percent=args.end_percent
|
70 |
+
)
|
utils/symbol_table.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
14 |
+
|
15 |
+
Copyright PolyAI Limited.
|
16 |
+
"""
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from typing import Dict, Generic, List, Optional, TypeVar, Union
|
19 |
+
|
20 |
+
Symbol = TypeVar('Symbol')
|
21 |
+
|
22 |
+
|
23 |
+
# Disable __repr__ otherwise it could freeze e.g. Jupyter.
|
24 |
+
@dataclass(repr=False)
|
25 |
+
class SymbolTable(Generic[Symbol]):
|
26 |
+
'''SymbolTable that maps symbol IDs, found on the FSA arcs to
|
27 |
+
actual objects. These objects can be arbitrary Python objects
|
28 |
+
that can serve as keys in a dictionary (i.e. they need to be
|
29 |
+
hashable and immutable).
|
30 |
+
|
31 |
+
The SymbolTable can only be read to/written from disk if the
|
32 |
+
symbols are strings.
|
33 |
+
'''
|
34 |
+
_id2sym: Dict[int, Symbol] = field(default_factory=dict)
|
35 |
+
'''Map an integer to a symbol.
|
36 |
+
'''
|
37 |
+
|
38 |
+
_sym2id: Dict[Symbol, int] = field(default_factory=dict)
|
39 |
+
'''Map a symbol to an integer.
|
40 |
+
'''
|
41 |
+
|
42 |
+
_next_available_id: int = 1
|
43 |
+
'''A helper internal field that helps adding new symbols
|
44 |
+
to the table efficiently.
|
45 |
+
'''
|
46 |
+
|
47 |
+
eps: Symbol = '<eps>'
|
48 |
+
'''Null symbol, always mapped to index 0.
|
49 |
+
'''
|
50 |
+
|
51 |
+
def __post_init__(self):
|
52 |
+
for idx, sym in self._id2sym.items():
|
53 |
+
assert self._sym2id[sym] == idx
|
54 |
+
assert idx >= 0
|
55 |
+
|
56 |
+
for sym, idx in self._sym2id.items():
|
57 |
+
assert idx >= 0
|
58 |
+
assert self._id2sym[idx] == sym
|
59 |
+
|
60 |
+
if 0 not in self._id2sym:
|
61 |
+
self._id2sym[0] = self.eps
|
62 |
+
self._sym2id[self.eps] = 0
|
63 |
+
else:
|
64 |
+
assert self._id2sym[0] == self.eps
|
65 |
+
assert self._sym2id[self.eps] == 0
|
66 |
+
|
67 |
+
self._next_available_id = max(self._id2sym) + 1
|
68 |
+
|
69 |
+
@staticmethod
|
70 |
+
def from_str(s: str) -> 'SymbolTable':
|
71 |
+
'''Build a symbol table from a string.
|
72 |
+
|
73 |
+
The string consists of lines. Every line has two fields separated
|
74 |
+
by space(s), tab(s) or both. The first field is the symbol and the
|
75 |
+
second the integer id of the symbol.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
s:
|
79 |
+
The input string with the format described above.
|
80 |
+
Returns:
|
81 |
+
An instance of :class:`SymbolTable`.
|
82 |
+
'''
|
83 |
+
id2sym: Dict[int, str] = dict()
|
84 |
+
sym2id: Dict[str, int] = dict()
|
85 |
+
|
86 |
+
for line in s.split('\n'):
|
87 |
+
fields = line.split()
|
88 |
+
if len(fields) == 0:
|
89 |
+
continue # skip empty lines
|
90 |
+
assert len(fields) == 2, \
|
91 |
+
f'Expect a line with 2 fields. Given: {len(fields)}'
|
92 |
+
sym, idx = fields[0], int(fields[1])
|
93 |
+
assert sym not in sym2id, f'Duplicated symbol {sym}'
|
94 |
+
assert idx not in id2sym, f'Duplicated id {idx}'
|
95 |
+
id2sym[idx] = sym
|
96 |
+
sym2id[sym] = idx
|
97 |
+
|
98 |
+
eps = id2sym.get(0, '<eps>')
|
99 |
+
|
100 |
+
return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps)
|
101 |
+
|
102 |
+
@staticmethod
|
103 |
+
def from_file(filename: str) -> 'SymbolTable':
|
104 |
+
'''Build a symbol table from file.
|
105 |
+
|
106 |
+
Every line in the symbol table file has two fields separated by
|
107 |
+
space(s), tab(s) or both. The following is an example file:
|
108 |
+
|
109 |
+
.. code-block::
|
110 |
+
|
111 |
+
<eps> 0
|
112 |
+
a 1
|
113 |
+
b 2
|
114 |
+
c 3
|
115 |
+
|
116 |
+
Args:
|
117 |
+
filename:
|
118 |
+
Name of the symbol table file. Its format is documented above.
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
An instance of :class:`SymbolTable`.
|
122 |
+
|
123 |
+
'''
|
124 |
+
with open(filename, 'r', encoding='utf-8') as f:
|
125 |
+
return SymbolTable.from_str(f.read().strip())
|
126 |
+
|
127 |
+
def to_str(self) -> str:
|
128 |
+
'''
|
129 |
+
Returns:
|
130 |
+
Return a string representation of this object. You can pass
|
131 |
+
it to the method ``from_str`` to recreate an identical object.
|
132 |
+
'''
|
133 |
+
s = ''
|
134 |
+
for idx, symbol in sorted(self._id2sym.items()):
|
135 |
+
s += f'{symbol} {idx}\n'
|
136 |
+
return s
|
137 |
+
|
138 |
+
def to_file(self, filename: str):
|
139 |
+
'''Serialize the SymbolTable to a file.
|
140 |
+
|
141 |
+
Every line in the symbol table file has two fields separated by
|
142 |
+
space(s), tab(s) or both. The following is an example file:
|
143 |
+
|
144 |
+
.. code-block::
|
145 |
+
|
146 |
+
<eps> 0
|
147 |
+
a 1
|
148 |
+
b 2
|
149 |
+
c 3
|
150 |
+
|
151 |
+
Args:
|
152 |
+
filename:
|
153 |
+
Name of the symbol table file. Its format is documented above.
|
154 |
+
'''
|
155 |
+
with open(filename, 'w') as f:
|
156 |
+
for idx, symbol in sorted(self._id2sym.items()):
|
157 |
+
print(symbol, idx, file=f)
|
158 |
+
|
159 |
+
def add(self, symbol: Symbol, index: Optional[int] = None) -> int:
|
160 |
+
'''Add a new symbol to the SymbolTable.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
symbol:
|
164 |
+
The symbol to be added.
|
165 |
+
index:
|
166 |
+
Optional int id to which the symbol should be assigned.
|
167 |
+
If it is not available, a ValueError will be raised.
|
168 |
+
|
169 |
+
Returns:
|
170 |
+
The int id to which the symbol has been assigned.
|
171 |
+
'''
|
172 |
+
# Already in the table? Return its ID.
|
173 |
+
if symbol in self._sym2id:
|
174 |
+
return self._sym2id[symbol]
|
175 |
+
# Specific ID not provided - use next available.
|
176 |
+
if index is None:
|
177 |
+
index = self._next_available_id
|
178 |
+
# Specific ID provided but not available.
|
179 |
+
if index in self._id2sym:
|
180 |
+
raise ValueError(f"Cannot assign id '{index}' to '{symbol}' - "
|
181 |
+
f"already occupied by {self._id2sym[index]}")
|
182 |
+
self._sym2id[symbol] = index
|
183 |
+
self._id2sym[index] = symbol
|
184 |
+
|
185 |
+
# Update next available ID if needed
|
186 |
+
if self._next_available_id <= index:
|
187 |
+
self._next_available_id = index + 1
|
188 |
+
|
189 |
+
return index
|
190 |
+
|
191 |
+
def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]:
|
192 |
+
'''Get a symbol for an id or get an id for a symbol
|
193 |
+
|
194 |
+
Args:
|
195 |
+
k:
|
196 |
+
If it is an id, it tries to find the symbol corresponding
|
197 |
+
to the id; if it is a symbol, it tries to find the id
|
198 |
+
corresponding to the symbol.
|
199 |
+
|
200 |
+
Returns:
|
201 |
+
An id or a symbol depending on the given `k`.
|
202 |
+
'''
|
203 |
+
if isinstance(k, int):
|
204 |
+
return self._id2sym[k]
|
205 |
+
else:
|
206 |
+
return self._sym2id[k]
|
207 |
+
|
208 |
+
def merge(self, other: 'SymbolTable') -> 'SymbolTable':
|
209 |
+
'''Create a union of two SymbolTables.
|
210 |
+
Raises an AssertionError if the same IDs are occupied by
|
211 |
+
different symbols.
|
212 |
+
|
213 |
+
Args:
|
214 |
+
other:
|
215 |
+
A symbol table to merge with ``self``.
|
216 |
+
|
217 |
+
Returns:
|
218 |
+
A new symbol table.
|
219 |
+
'''
|
220 |
+
self._check_compatible(other)
|
221 |
+
|
222 |
+
id2sym = {**self._id2sym, **other._id2sym}
|
223 |
+
sym2id = {**self._sym2id, **other._sym2id}
|
224 |
+
|
225 |
+
return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=self.eps)
|
226 |
+
|
227 |
+
def _check_compatible(self, other: 'SymbolTable') -> None:
|
228 |
+
# Epsilon compatibility
|
229 |
+
assert self.eps == other.eps, f'Mismatched epsilon symbol: ' \
|
230 |
+
f'{self.eps} != {other.eps}'
|
231 |
+
# IDs compatibility
|
232 |
+
common_ids = set(self._id2sym).intersection(other._id2sym)
|
233 |
+
for idx in common_ids:
|
234 |
+
assert self[idx] == other[idx], f'ID conflict for id: {idx}, ' \
|
235 |
+
f'self[idx] = "{self[idx]}", ' \
|
236 |
+
f'other[idx] = "{other[idx]}"'
|
237 |
+
# Symbols compatibility
|
238 |
+
common_symbols = set(self._sym2id).intersection(other._sym2id)
|
239 |
+
for sym in common_symbols:
|
240 |
+
assert self[sym] == other[sym], f'ID conflict for id: {sym}, ' \
|
241 |
+
f'self[sym] = "{self[sym]}", ' \
|
242 |
+
f'other[sym] = "{other[sym]}"'
|
243 |
+
|
244 |
+
def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]:
|
245 |
+
return self.get(item)
|
246 |
+
|
247 |
+
def __contains__(self, item: Union[int, Symbol]) -> bool:
|
248 |
+
if isinstance(item, int):
|
249 |
+
return item in self._id2sym
|
250 |
+
else:
|
251 |
+
return item in self._sym2id
|
252 |
+
|
253 |
+
def __len__(self) -> int:
|
254 |
+
return len(self._id2sym)
|
255 |
+
|
256 |
+
def __eq__(self, other: 'SymbolTable') -> bool:
|
257 |
+
if len(self) != len(other):
|
258 |
+
return False
|
259 |
+
|
260 |
+
for s in self.symbols:
|
261 |
+
if self[s] != other[s]:
|
262 |
+
return False
|
263 |
+
|
264 |
+
return True
|
265 |
+
|
266 |
+
@property
|
267 |
+
def ids(self) -> List[int]:
|
268 |
+
'''Returns a list of integer IDs corresponding to the symbols.
|
269 |
+
'''
|
270 |
+
ans = list(self._id2sym.keys())
|
271 |
+
ans.sort()
|
272 |
+
return ans
|
273 |
+
|
274 |
+
@property
|
275 |
+
def symbols(self) -> List[Symbol]:
|
276 |
+
'''Returns a list of symbols (e.g., strings) corresponding to
|
277 |
+
the integer IDs.
|
278 |
+
'''
|
279 |
+
ans = list(self._sym2id.keys())
|
280 |
+
ans.sort()
|
281 |
+
return ans
|