taras-sereda committed
Commit 96ee597
1 Parent(s): 694ecc6

minimal set of files to run inference; pheme-small checkpoint

ckpt/s2a/config.json ADDED
@@ -0,0 +1,69 @@
+ {
+ "saving_path": "/home/ubuntu/experiments/a2s_giga2",
+ "resume_checkpoint": null,
+ "vocoder_type": "SPEECHTOKENIZER",
+ "vocoder_config_path": null,
+ "vocoder_ckpt_path": null,
+ "metapath": [
+ "/home/ubuntu/data/poly/giga-training-data/train.json"
+ ],
+ "val_metapath": [
+ "/home/ubuntu/data/poly/giga-training-data/dev.json"
+ ],
+ "pretrained_path": null,
+ "speaker_embedding_dir": null,
+ "sampledir": "/home/ubuntu/experiments/a2s_giga2",
+ "lr": 0.0005,
+ "batch_size": 400.0,
+ "train_bucket_size": 8192,
+ "training_step": 800000,
+ "optim_flat_percent": 0.0,
+ "warmup_step": 10000,
+ "adam_beta1": 0.9,
+ "adam_beta2": 0.98,
+ "ffd_size": 1024,
+ "hidden_size": 768,
+ "enc_nlayers": 3,
+ "dec_nlayers": 6,
+ "nheads": 8,
+ "dropout": 0.1,
+ "depthwise_conv_kernel_size": 5,
+ "aligner_softmax_temp": 1.0,
+ "layer_norm_eps": 1e-05,
+ "use_sem_tokens": true,
+ "use_spkr_emb": true,
+ "use_text_emb": false,
+ "fairseq": false,
+ "only_inference": false,
+ "speaker_embed_dropout": 0.05,
+ "label_smoothing": 0.0,
+ "val_check_interval": 1,
+ "max_dataset_samples": -1,
+ "check_val_every_n_epoch": 1,
+ "precision": "bf16",
+ "nworkers": 12,
+ "distributed": true,
+ "accelerator": "gpu",
+ "version": null,
+ "accumulate_grad_batches": 1,
+ "sagemaker": false,
+ "use_repetition_token": false,
+ "use_repetition_gating": false,
+ "repetition_penalty": 1.0,
+ "sampling_temperature": 1.0,
+ "top_k": -1,
+ "min_top_k": 3,
+ "top_p": 0.8,
+ "sample_num": 4,
+ "length_penalty_max_length": 150,
+ "length_penalty_max_prob": 0.95,
+ "max_input_length": 2048,
+ "max_output_length": 2000,
+ "phone_context_window": 3,
+ "sample_rate": 16000,
+ "n_codes": 1024,
+ "n_cluster_groups": 7,
+ "first_n_lvls": 7,
+ "use_pretrained_ckpt_cfg": false,
+ "n_semantic_codes": 1024
+ }
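The S2A code in this commit reads this file as an attribute-style hyper-parameter object (hp.n_codes, hp.batch_size, and so on). A minimal sketch of that loading step, assuming the checkpoint is used from the repository root; the SimpleNamespace choice is an illustration, not the repo's own loader:

import json
from types import SimpleNamespace

# Assumed path relative to the repository root; adjust as needed.
with open("ckpt/s2a/config.json") as f:
    hp = SimpleNamespace(**json.load(f))

# Fields referenced by the S2A modules added below.
print(hp.vocoder_type, hp.n_codes, hp.n_semantic_codes, hp.n_cluster_groups)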
ckpt/s2a/s2a.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f083b82b7a2e902cf318163562f6b3faa2d2b4ce72d20b92e3de8b8dfc125383
+ size 671800067
ckpt/t2s/config.json ADDED
@@ -0,0 +1,29 @@
+ {
+ "architectures": [
+ "T5ForConditionalGeneration"
+ ],
+ "classifier_dropout": 0.0,
+ "d_ff": 2048,
+ "d_kv": 64,
+ "d_model": 512,
+ "decoder_start_token_id": 0,
+ "dense_act_fn": "relu",
+ "dropout_rate": 0.1,
+ "eos_token_id": 2,
+ "feed_forward_proj": "relu",
+ "initializer_factor": 1.0,
+ "is_encoder_decoder": true,
+ "is_gated_act": false,
+ "layer_norm_epsilon": 1e-06,
+ "model_type": "t5",
+ "num_decoder_layers": 6,
+ "num_heads": 8,
+ "num_layers": 6,
+ "pad_token_id": 0,
+ "relative_attention_max_distance": 128,
+ "relative_attention_num_buckets": 32,
+ "torch_dtype": "float32",
+ "transformers_version": "4.34.1",
+ "use_cache": true,
+ "vocab_size": 1119
+ }
ckpt/t2s/generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+ "_from_model_config": true,
+ "decoder_start_token_id": 0,
+ "eos_token_id": 2,
+ "pad_token_id": 0,
+ "transformers_version": "4.34.1"
+ }
ckpt/t2s/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:00aab5d35aa54a7cc7c81c92c706928e33d82a5c7e63b3628df48b0aed28606f
+ size 178565209
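ckpt/t2s/ is laid out as a standard Hugging Face T5 checkpoint (config, generation config, weights), so it can be loaded with transformers. A minimal sketch, assuming the files sit under ckpt/t2s/ locally; note that input ids must come from this project's own symbol table and collater (vocab_size is 1119), not from a stock T5 tokenizer:

import torch
from transformers import T5ForConditionalGeneration

t2s = T5ForConditionalGeneration.from_pretrained("ckpt/t2s")
t2s.eval()

# Hypothetical phoneme/semantic token ids produced by the collater added below.
input_ids = torch.tensor([[1, 5, 6, 7, 2]])
with torch.no_grad():
    semantic_tokens = t2s.generate(input_ids, max_length=16)
print(semantic_tokens.shape)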
constants.py ADDED
@@ -0,0 +1,14 @@
+ """Constants file.
+
+ Copyright PolyAI Limited.
+ """
+ SPKR_EMB_SIZE = 512
+
+ PAD = 1024
+
+ SPKR_1 = 1025
+ SPKR_2 = 1026
+
+ BOS_TOKEN_ID = 0
+ PAD_TOKEN_ID = 0
+ EOS_TOKEN_ID = 2
data/collation.py ADDED
@@ -0,0 +1,182 @@
1
+ """Collators for T2S and S2A.
2
+
3
+ Copyright PolyAI Limited.
4
+ """
5
+ from pathlib import Path
6
+ from typing import List, Tuple, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ from utils.symbol_table import SymbolTable
12
+
13
+
14
+ class GlobalCollater:
15
+ def __init__(self, n_codes, n_semantic_codes):
16
+ self.n_codes = n_codes
17
+ self.sem_mask_id = n_semantic_codes
18
+
19
+ def collate(self, batch):
20
+ output = {
21
+ 'speaker': [],
22
+ 'tts_quantize_input': [],
23
+ 'tts_quantize_output': [],
24
+ 'quantize_mask': [],
25
+ 'f_names': [],
26
+ 'semantic_tokens': [],
27
+ 'quantization_lengths': [],
28
+ }
29
+ # Get the max length of everything
30
+ max_len_q = 0
31
+ for _, q_s, q_e, _, _ in batch:
32
+ if len(q_s) > max_len_q:
33
+ max_len_q = len(q_s)
34
+
35
+ output['quantization_lengths'].append(len(q_s))
36
+
37
+ # Pad each element, create mask
38
+ for spkr, qs, qe, itm_name, s_tokens in batch:
39
+ # Deal with quantizations
40
+ q_mask = np.array(
41
+ [False] * len(qs) + [True] * (max_len_q - len(qs)))
42
+ qs = np.pad(
43
+ qs,
44
+ [[0, max_len_q-len(qs)], [0, 0]],
45
+ constant_values=self.n_codes
46
+ )
47
+ qe = np.pad(
48
+ qe,
49
+ [[0, max_len_q-len(qe)], [0, 0]],
50
+ constant_values=self.n_codes
51
+ )
52
+
53
+ # Deal with semantics
54
+ s_tokens = s_tokens.flatten()
55
+ s_tokens = np.pad(
56
+ s_tokens,
57
+ (0, max_len_q-len(s_tokens)),
58
+ constant_values=self.sem_mask_id
59
+ )
60
+
61
+ # Speaker padding
62
+ spkr = np.concatenate(
63
+ (spkr, np.zeros((max_len_q - len(spkr), 512))))
64
+
65
+ # Aggregate
66
+ output['speaker'].append(spkr)
67
+ output['tts_quantize_input'].append(qs)
68
+ output['tts_quantize_output'].append(qe)
69
+ output['quantize_mask'].append(q_mask)
70
+ output['f_names'].append(itm_name)
71
+ output["semantic_tokens"].append(s_tokens)
72
+
73
+ for k in output.keys():
74
+ if k == 'f_names':
75
+ continue
76
+ output[k] = np.array(output[k])
77
+ if 'mask' in k:
78
+ output[k] = torch.BoolTensor(output[k])
79
+ elif k in [
80
+ 'tts_quantize_input', 'tts_quantize_output',
81
+ 'semantic_tokens', 'quantization_lengths'
82
+ ]:
83
+ output[k] = torch.LongTensor(output[k])
84
+ else:
85
+ output[k] = torch.FloatTensor(output[k])
86
+ return output
87
+
88
+
89
+ class TextTokenCollater:
90
+ def __init__(
91
+ self,
92
+ text_tokens: List[str],
93
+ add_eos: bool = True,
94
+ add_bos: bool = True,
95
+ pad_symbol: str = "<pad>",
96
+ bos_symbol: str = "<bos>",
97
+ eos_symbol: str = "<eos>",
98
+ spkr_1_symbol: str = "spkr_1",
99
+ spkr_2_symbol: str = "spkr_2",
100
+ ):
101
+ self.pad_symbol = pad_symbol
102
+
103
+ self.add_eos = add_eos
104
+ self.add_bos = add_bos
105
+
106
+ self.bos_symbol = bos_symbol
107
+ self.eos_symbol = eos_symbol
108
+ self.spkr_1_symbol = spkr_1_symbol
109
+ self.spkr_2_symbol = spkr_2_symbol
110
+
111
+ unique_tokens = (
112
+ [pad_symbol]
113
+ + ([bos_symbol] if add_bos else [])
114
+ + ([eos_symbol] if add_eos else [])
115
+ + ([spkr_1_symbol])
116
+ + ([spkr_2_symbol])
117
+ + sorted(text_tokens)
118
+ )
119
+
120
+ self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
121
+ self.idx2token = [token for token in unique_tokens]
122
+
123
+ def __call__(
124
+ self, texts: List[str], texts_2: Union[None, List[str]] = None
125
+ ) -> torch.Tensor:
126
+ tokens_seqs = [[p for p in text] for text in texts]
127
+
128
+ if texts_2 is None:
129
+ seqs = [
130
+ ([self.bos_symbol] if self.add_bos else [])
131
+ + [self.spkr_1_symbol]
132
+ + list(seq)
133
+ + ([self.eos_symbol] if self.add_eos else [])
134
+ for seq in tokens_seqs
135
+ ]
136
+ else:
137
+ tokens_seqs_2 = [[p for p in text] for text in texts_2]
138
+ seqs = [
139
+ ([self.bos_symbol] if self.add_bos else [])
140
+ + [self.spkr_1_symbol]
141
+ + list(seq)
142
+ + ([self.spkr_2_symbol])
143
+ + list(seq_2)
144
+ + ([self.eos_symbol] if self.add_eos else [])
145
+ for seq, seq_2 in zip(tokens_seqs, tokens_seqs_2)
146
+ ]
147
+
148
+ tokens_batch = torch.from_numpy(
149
+ np.array(
150
+ [[self.token2idx[token] for token in seq] for seq in seqs],
151
+ dtype=np.int64,
152
+ )
153
+ )
154
+
155
+ return tokens_batch
156
+
157
+
158
+ def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater:
159
+ text_tokens_path = Path(text_tokens_file)
160
+ unique_tokens = SymbolTable.from_file(text_tokens_path)
161
+ collater = TextTokenCollater(
162
+ unique_tokens.symbols, add_bos=True, add_eos=True
163
+ )
164
+ return collater
165
+
166
+
167
+ def get_text_semantic_token_collater(
168
+ text_tokens_file: str, n_semantic_tokens=1024) -> TextTokenCollater:
169
+ text_tokens_path = Path(text_tokens_file)
170
+ unique_tokens = SymbolTable.from_file(text_tokens_path)
171
+ for semantic_idx in range(n_semantic_tokens):
172
+ unique_tokens.add(str(semantic_idx))
173
+
174
+ collater = TextTokenCollater(
175
+ unique_tokens.symbols, add_bos=True, add_eos=True
176
+ )
177
+ return collater
178
+
179
+
180
+ if __name__ == '__main__':
181
+ text_tokens_file = 'ckpt/unique_text_tokens.k2symbols'
182
+ collater = get_text_semantic_token_collater(text_tokens_file)
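A small usage sketch (assumed, not part of the commit) of GlobalCollater: two variable-length items with 7 acoustic codebooks, a 512-dim speaker embedding per frame and one semantic stream, padded into a single batch. The item layout follows what data/single_speaker_dataset.py returns:

import numpy as np
from data.collation import GlobalCollater

def fake_item(T, n_q=7):
    spkr = np.random.randn(T, 512)                     # speaker embedding per frame
    codes = np.random.randint(0, 1024, size=(T, n_q))  # acoustic tokens
    sem = np.random.randint(0, 1024, size=(1, T))      # semantic tokens
    return spkr, codes, codes, f"utt_{T}.wav", sem

collater = GlobalCollater(n_codes=1024, n_semantic_codes=1024)
batch = collater.collate([fake_item(50), fake_item(80)])
print(batch["tts_quantize_input"].shape)  # torch.Size([2, 80, 7])
print(batch["quantize_mask"].shape)       # torch.Size([2, 80])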
data/data_module.py ADDED
@@ -0,0 +1,119 @@
1
+ """Data module.
2
+
3
+ Copyright PolyAI Limited.
4
+ """
5
+ import typing
6
+ from pathlib import Path
7
+ from typing import List
8
+
9
+ import lightning.pytorch as pl
10
+ from torch.utils import data
11
+
12
+ from data.collation import GlobalCollater
13
+ from data.sampler import RandomBucketSampler
14
+ from data.single_speaker_dataset import QuantizeDataset
15
+ from utils import breakpoint_on_error
16
+
17
+
18
+ class ConcatDataset(data.ConcatDataset):
19
+ def __init__(self, datasets) -> None:
20
+ super().__init__(datasets)
21
+ self.lengths = []
22
+ for dataset in datasets:
23
+ self.lengths.extend(dataset.lengths)
24
+
25
+
26
+ class DataModule(pl.LightningDataModule):
27
+ def __init__(
28
+ self, hp, metapath: List[str], val_metapath: List[str],
29
+ world_size, local_rank
30
+ ):
31
+ super().__init__()
32
+ self.hp = hp
33
+ self.metapath = metapath
34
+ self.val_metapath = val_metapath
35
+ self.world_size = world_size
36
+ self.local_rank = local_rank
37
+ self.collater = GlobalCollater(
38
+ self.hp.n_codes, self.hp.n_semantic_codes)
39
+
40
+ def setup(self, stage: str) -> None:
41
+ if stage == "fit":
42
+ self.train_data = self.concatenate_datasets(
43
+ self.metapath, dataset_class=QuantizeDataset
44
+ )
45
+
46
+ if stage == "valid":
47
+ self.val_data = []
48
+ self.val_data_keys = []
49
+ self.prepare_val_datasets()
50
+ assert len(self.val_data) > 0
51
+ assert len(self.val_data_keys) > 0
52
+
53
+ @breakpoint_on_error
54
+ def concatenate_datasets(
55
+ self, metapaths, dataset_class: typing.Type[QuantizeDataset]):
56
+ data = []
57
+ for _, metapath in enumerate(metapaths):
58
+ metapath = Path(metapath)
59
+ # assumption that audios and audios-embeddings
60
+ # are in the same folder as metapath
61
+ datadir = metapath.with_name("audios")
62
+ assert datadir.exists()
63
+ data.append(
64
+ dataset_class(
65
+ self.hp,
66
+ metapath,
67
+ datadir=datadir,
68
+ speaker_embedding_dir=None,
69
+ )
70
+ )
71
+ return ConcatDataset(data)
72
+
73
+ def prepare_val_datasets(self):
74
+ for manifest in self.val_metapath:
75
+ self.val_data.append(
76
+ self.concatenate_datasets(
77
+ [manifest], dataset_class=QuantizeDataset)
78
+ )
79
+ name = Path(manifest).parent.name
80
+ self.val_data_keys.append(name)
81
+
82
+ assert len(self.val_data) == len(self.val_data_keys)
83
+
84
+ def train_dataloader(self):
85
+ length = self.train_data.lengths
86
+ sampler = RandomBucketSampler(
87
+ self.hp.train_bucket_size,
88
+ length,
89
+ self.hp.batch_size,
90
+ drop_last=True,
91
+ distributed=self.hp.distributed,
92
+ world_size=self.world_size,
93
+ rank=self.local_rank,
94
+ )
95
+ dataloader = data.DataLoader(
96
+ self.train_data,
97
+ num_workers=self.hp.nworkers,
98
+ batch_sampler=sampler,
99
+ collate_fn=self.collater.collate,
100
+ pin_memory=True
101
+ )
102
+
103
+ return dataloader
104
+
105
+ def val_dataloader(self):
106
+ val_loaders = []
107
+ for dataset in self.val_data:
108
+ val_loaders.append(
109
+ data.DataLoader(
110
+ dataset,
111
+ num_workers=self.hp.nworkers,
112
+ batch_size=int(self.hp.batch_size),
113
+ collate_fn=self.collater.collate,
114
+ shuffle=False,
115
+ pin_memory=True
116
+ )
117
+ )
118
+
119
+ return val_loaders
data/sampler.py ADDED
@@ -0,0 +1,115 @@
1
+ """Original sampling logic of MQTTS.
2
+
3
+ Copyright PolyAI Limited.
4
+ """
5
+ import math
6
+ import random
7
+
8
+ import numpy as np
9
+ from torch.utils import data
10
+
11
+
12
+ def StandardSampler(dataset, shuffle, distributed=False,
13
+ world_size=None, rank=None):
14
+ if distributed:
15
+ return data.distributed.DistributedSampler(
16
+ dataset, shuffle=shuffle, num_replicas=world_size, rank=rank)
17
+ if shuffle:
18
+ return data.RandomSampler(dataset)
19
+ return data.SequentialSampler(dataset)
20
+
21
+
22
+ def RandomBucketSampler(
23
+ nbuckets, length, batch_size, drop_last, distributed=False,
24
+ world_size=None, rank=None):
25
+ if distributed:
26
+ return DistributedRandomBucketSampler(
27
+ nbuckets, length, batch_size, drop_last, world_size, rank)
28
+ return SingleRandomBucketSampler(nbuckets, length, batch_size, drop_last)
29
+
30
+
31
+ class SingleRandomBucketSampler(data.Sampler):
32
+ def __init__(self, nbuckets, length, batch_size, drop_last):
33
+ self.length = length
34
+ self.batch_size = batch_size
35
+ self.drop_last = drop_last
36
+ indices = np.argsort([-x for x in length])
37
+ split = len(indices) // nbuckets
38
+ self.indices = []
39
+ for i in range(nbuckets):
40
+ self.indices.append(indices[i*split:(i+1)*split])
41
+ if nbuckets * split < len(length):
42
+ self.indices.append(indices[nbuckets*split:])
43
+
44
+ def __iter__(self):
45
+ random.shuffle(self.indices)
46
+ for x in self.indices:
47
+ random.shuffle(x)
48
+ idxs = [i for x in self.indices for i in x]
49
+ batches, batch, sum_len, max_len = [], [], 0, 0
50
+ for idx in idxs:
51
+ batch.append(idx)
52
+ sum_len += self.length[idx]
53
+ max_len = max(self.length[idx], max_len)
54
+ if max_len * len(batch) > self.batch_size:
55
+ batches.append(batch[:-1])
56
+ batch, sum_len, max_len = [batch[-1]], self.length[idx], self.length[idx] # noqa
57
+ if len(batch) > 0 and not self.drop_last:
58
+ batches.append(batch)
59
+ random.shuffle(batches)
60
+ return iter(batches)
61
+
62
+
63
+ class DistributedRandomBucketSampler(data.Sampler):
64
+ def __init__(self, nbuckets, length, batch_size,
65
+ drop_last, num_replicas, rank, seed=1234):
66
+ if rank >= num_replicas or rank < 0:
67
+ raise ValueError(
68
+ "Invalid rank {}, rank should be in the interval"
69
+ " [0, {}]".format(rank, num_replicas - 1))
70
+ indices = np.argsort(length)
71
+ split = len(indices) // nbuckets
72
+ self.length = length
73
+ self.batch_size = batch_size
74
+ self.drop_last = drop_last
75
+ self.indices = []
76
+ for i in range(nbuckets):
77
+ self.indices.append(indices[i*split:(i+1)*split])
78
+ if nbuckets * split < len(length):
79
+ self.indices.append(indices[nbuckets*split:])
80
+ self.num_replicas = num_replicas
81
+ self.rank = rank
82
+ self.epoch = 0
83
+ self.seed = seed
84
+
85
+ def __iter__(self):
86
+ # Deterministic shuffling
87
+ random.Random(self.epoch + self.seed).shuffle(self.indices)
88
+ for i, x in enumerate(self.indices):
89
+ seed = self.epoch + self.seed + i * 5
90
+ random.Random(seed).shuffle(x)
91
+ indices = [i for x in self.indices for i in x]
92
+
93
+ # Batching
94
+ batches, batch, sum_len, max_len = [], [], 0, 0
95
+ for idx in indices:
96
+ batch.append(idx)
97
+ sum_len += self.length[idx]
98
+ max_len = max(self.length[idx], max_len)
99
+ if max_len * len(batch) > self.batch_size:
100
+ batches.append(batch[:-1])
101
+ batch, sum_len, max_len = [batch[-1]], self.length[idx], self.length[idx] # noqa
102
+ # Subsample
103
+ num_samples = math.ceil(
104
+ (len(batches) - self.num_replicas) / self.num_replicas)
105
+ total_size = num_samples * self.num_replicas
106
+ batches = batches[:total_size]
107
+ batches = batches[self.rank*num_samples: (self.rank+1)*num_samples]
108
+ assert len(batches) == num_samples
109
+
110
+ # Stochastic shuffling
111
+ random.shuffle(batches)
112
+ return iter(batches)
113
+
114
+ def set_epoch(self, epoch):
115
+ self.epoch = epoch
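Assumed standalone usage of RandomBucketSampler (the DataModule above passes it to DataLoader as batch_sampler): utterances are bucketed by length and each emitted batch respects a max_len * batch_size budget. Toy lengths only:

from data.sampler import RandomBucketSampler

lengths = [3.2, 7.5, 1.1, 4.8, 6.0, 2.4]  # toy per-utterance durations
sampler = RandomBucketSampler(
    nbuckets=2, length=lengths, batch_size=8.0, drop_last=False)
for batch_indices in sampler:
    print(batch_indices)  # lists of dataset indices, e.g. [1], [4, 0], ...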
data/semantic_dataset.py ADDED
@@ -0,0 +1,207 @@
1
+ """Semantic tokens loading logic.
2
+
3
+ Copyright PolyAI Limited.
4
+ """
5
+ import json
6
+ import logging
7
+ import random
8
+ import re
9
+ from logging import getLogger
10
+ from pathlib import Path
11
+ from typing import List, Pattern, Union
12
+
13
+ import numpy as np
14
+ import torch
15
+ from phonemizer.backend import EspeakBackend
16
+ from phonemizer.backend.espeak.language_switch import LanguageSwitch
17
+ from phonemizer.backend.espeak.words_mismatch import WordMismatch
18
+ from phonemizer.punctuation import Punctuation
19
+ from phonemizer.separator import Separator
20
+ from torch.utils.data import DataLoader, Dataset
21
+ from tqdm import tqdm
22
+
23
+ from data.collation import get_text_semantic_token_collater
24
+
25
+
26
+ class TextTokenizer:
27
+ """Phonemize Text."""
28
+
29
+ def __init__(
30
+ self,
31
+ language="en-us",
32
+ backend="espeak",
33
+ separator=Separator(word="_", syllable="-", phone="|"),
34
+ preserve_punctuation=True,
35
+ punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
36
+ with_stress: bool = False,
37
+ tie: Union[bool, str] = False,
38
+ language_switch: LanguageSwitch = "keep-flags",
39
+ words_mismatch: WordMismatch = "ignore",
40
+ ) -> None:
41
+ logger = getLogger("phonemizer")
42
+ logger.setLevel(logging.ERROR)
43
+ if backend == "espeak":
44
+ phonemizer = EspeakBackend(
45
+ language,
46
+ punctuation_marks=punctuation_marks,
47
+ preserve_punctuation=preserve_punctuation,
48
+ with_stress=with_stress,
49
+ tie=tie,
50
+ language_switch=language_switch,
51
+ words_mismatch=words_mismatch,
52
+ logger=logger,
53
+ )
54
+ else:
55
+ raise NotImplementedError(f"{backend}")
56
+
57
+ self.backend = phonemizer
58
+ self.separator = separator
59
+
60
+ def to_list(self, phonemized: str) -> List[str]:
61
+ fields = []
62
+ for word in phonemized.split(self.separator.word):
63
+ # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
64
+ pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
65
+ fields.extend(
66
+ [p for p in pp if p != self.separator.phone] + [self.separator.word]
67
+ )
68
+ assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
69
+ self.separator.phone
70
+ )
71
+ return fields[:-1]
72
+
73
+ def __call__(self, text, strip=True) -> List[List[str]]:
74
+ if isinstance(text, str):
75
+ text = [text]
76
+
77
+ phonemized = self.backend.phonemize(
78
+ text, separator=self.separator, strip=strip, njobs=1
79
+ )
80
+ return [self.to_list(p) for p in phonemized]
81
+
82
+
83
+ class Collator:
84
+ def collate(self, batch):
85
+ input_ids = [item["input_ids"] for item in batch]
86
+ output_sequences = [item["labels"] for item in batch]
87
+
88
+ # Pad sequences to the maximum length in the batch
89
+ input_ids = torch.nn.utils.rnn.pad_sequence(
90
+ input_ids, batch_first=True, padding_value=0
91
+ )
92
+ output_sequences = torch.nn.utils.rnn.pad_sequence(
93
+ output_sequences, batch_first=True, padding_value=-100
94
+ )
95
+ # 1 - token is unmasked, 0 - token is masked.
96
+ attention_mask = input_ids != 0
97
+ return {
98
+ "input_ids": input_ids,
99
+ "attention_mask": attention_mask,
100
+ "labels": output_sequences,
101
+ }
102
+
103
+ class ConcatenateSemanticDataset(Dataset):
104
+ def __init__(
105
+ self, manifest_path: str, symbol_table_path: str,
106
+ n_samples: int = 0, max_duration=15):
107
+ self.data = []
108
+ self.phonemizer = TextTokenizer()
109
+ self.text_collater = get_text_semantic_token_collater(
110
+ symbol_table_path)
111
+ self.manifest_path = manifest_path
112
+ self.n_samples = n_samples
113
+ self.max_duration = max_duration
114
+ if manifest_path is not None:
115
+ self._build()
116
+
117
+ def __len__(self):
118
+ if self.n_samples:
119
+ return min(self.n_samples, len(self.data))
120
+ return len(self.data)
121
+
122
+ def remove_unknown_symbols(self, text: List[str]):
123
+ res = []
124
+ for sym in text:
125
+ if sym not in self.text_collater.token2idx:
126
+ # print(f'{sym} is unk')
127
+ continue
128
+ res.append(sym)
129
+ return res
130
+
131
+ def __getitem__(self, idx):
132
+ item = self.data[idx]
133
+
134
+ input_ids = item["phoneme"].split("|")
135
+ input_ids = self.remove_unknown_symbols(input_ids)
136
+
137
+ input_ids_2 = None
138
+ if item.get("phoneme_2"):
139
+ input_ids_2 = item["phoneme_2"].split("|")
140
+ input_ids_2 = [self.remove_unknown_symbols(input_ids_2)]
141
+
142
+ input_ids = self.text_collater(
143
+ [input_ids], input_ids_2).to(dtype=torch.long)
144
+ input_ids = input_ids.to(dtype=torch.long)
145
+
146
+ labels = np.load(item["semantic_path"])
147
+ labels = [str(lbl) for lbl in labels]
148
+
149
+ labels_2 = None
150
+ if item.get("semantic_path_2"):
151
+ labels_2 = np.load(item["semantic_path_2"])
152
+ labels_2 = [[str(lbl) for lbl in labels_2]]
153
+
154
+ labels = self.text_collater([labels], labels_2).to(dtype=torch.long)
155
+
156
+ return {"input_ids": input_ids.squeeze(0), "labels": labels.squeeze(0)}
157
+
158
+ # TODO - remove this to not load to the memory
159
+ def _build(self):
160
+ for manifest_path in self.manifest_path:
161
+ dataset_path = Path(manifest_path).parent
162
+
163
+ with open(manifest_path, "r") as manifest_file:
164
+ manifest_data = json.load(manifest_file)
165
+
166
+ for key, value in tqdm(manifest_data.items()):
167
+ if float(value["duration"]) > self.max_duration:
168
+ continue
169
+ text = value["text"]
170
+ phoneme = value["phoneme"]
171
+ npy_path = f"{dataset_path}/audios-speech-tokenizer/semantic/{key.split('.wav')[0]}.npy" # noqa
172
+ datapoint = {
173
+ "text": text,
174
+ "semantic_path": npy_path,
175
+ "phoneme": phoneme
176
+ }
177
+ self.data.append(datapoint)
178
+
179
+ print(f"Total length of the dataset {manifest_path}: {len(self.data)}")
180
+
181
+ random.shuffle(self.data)
182
+
183
+
184
+ if __name__ == "__main__":
185
+ # Create an instance of the dataset
186
+ manifest_path = "datasets/ljspeech-training-data/dev.json"
187
+ text_tokens_file = "ckpt/unique_text_tokens.k2symbols"
188
+ seq2seq_dataset = ConcatenateSemanticDataset(
189
+ [manifest_path, manifest_path], text_tokens_file)
190
+
191
+ # seq2seq_dataset.phonemize_and_rewrite_manifest()
192
+ batch_size = 1 # Adjust to your desired batch size
193
+ dataloader = DataLoader(
194
+ seq2seq_dataset,
195
+ batch_size=batch_size,
196
+ shuffle=True,
197
+ collate_fn=Collator().collate,
198
+ )
199
+
200
+ for batch in dataloader:
201
+ print(batch["input_ids"])
202
+ print(batch["labels"])
203
+ print(batch["input_ids"][0].unique().max())
204
+ print(batch["input_ids"][0].unique().min())
205
+ print(batch["input_ids"].shape)
206
+ print(batch["labels"].shape)
207
+ break # Stop after the first batch if needed
data/single_speaker_dataset.py ADDED
@@ -0,0 +1,167 @@
1
+ """Main loading function.
2
+
3
+ Copyright PolyAI Limited.
4
+ """
5
+ import json
6
+ import os
7
+ import random
8
+ from pathlib import Path
9
+
10
+ import numpy as np
11
+ import soundfile as sf
12
+ import torch
13
+ from librosa.util import normalize
14
+ from pyannote.audio import Inference
15
+ from torch.utils import data
16
+
17
+ import constants as c
18
+
19
+
20
+ def random_crop(x, maxseqlen):
21
+ if x.shape[0] >= maxseqlen:
22
+ offset = random.randrange(x.shape[0] - maxseqlen + 1)
23
+ x = x[offset: offset + maxseqlen]
24
+ else:
25
+ offset = 0
26
+ return x, offset
27
+
28
+
29
+ def dynamic_range_compression(x, C=0.3, M=6.5, clip_val=1e-5):
30
+ return (np.log(np.clip(x, a_min=clip_val, a_max=None)) + M) * C
31
+
32
+
33
+ def dynamic_range_decompression(x, C=0.3, M=6.5):
34
+ return np.exp(x / C - M)
35
+
36
+
37
+ class QuantizeDataset(data.Dataset):
38
+ def __init__(self, hp, metapath, datadir=None, speaker_embedding_dir=None):
39
+ self.hp = hp
40
+ self.datadir = Path(datadir)
41
+ self.speaker_embedding_dir = speaker_embedding_dir
42
+ self.sem_mask_id = hp.n_semantic_codes
43
+
44
+ print(f"Loading metadata in {metapath}...")
45
+ with open(metapath, "r") as f:
46
+ self.text = json.load(f)
47
+ if 0 < self.hp.max_dataset_samples < len(self.text):
48
+ self.new_text = {}
49
+ num = 0
50
+ for k, v in self.text.items():
51
+ if num >= self.hp.max_dataset_samples:
52
+ break
53
+ self.new_text[k] = v
54
+ num += 1
55
+ self.text = self.new_text
56
+
57
+ self.datasetbase = [x for x in self.text.keys()]
58
+ self.dataset = [
59
+ os.path.join(self.datadir, x) for x in self.datasetbase]
60
+
61
+ if self.speaker_embedding_dir is None:
62
+ self.spkr_embedding = Inference(
63
+ "pyannote/embedding",
64
+ window="whole",
65
+ use_auth_token=os.environ["HUGGING_FACE_HUB_TOKEN"],
66
+ )
67
+
68
+ # Print statistics:
69
+ n = len(self.dataset)
70
+ print(f"Total {n} examples")
71
+
72
+ self.lengths = [float(v["duration"]) for v in self.text.values()]
73
+ total_duration = sum(self.lengths)
74
+ avglen = total_duration / len(self.lengths)
75
+ maxlen = max(self.lengths)
76
+ minlen = min(self.lengths)
77
+ print(
78
+ f"Average duration of audio: {avglen} sec, "
79
+ "Maximum duration: {maxlen} sec, Minimum duration: {minlen} sec"
80
+ )
81
+
82
+ def __len__(self):
83
+ return len(self.dataset)
84
+
85
+ def load_quantization(self, _name):
86
+ if self.hp.vocoder_type == 'NATIVE':
87
+ metadata = self.text[_name]
88
+ quantization = np.array(metadata["quantization"]).T # ..., 4
89
+ elif self.hp.vocoder_type == 'DAC':
90
+ codes_path = self.datadir.parent / 'audios-dac' / (os.path.splitext(_name)[0] + ".npy") # noqa
91
+ quantization = np.load(codes_path).T # ..., 12
92
+ elif self.hp.vocoder_type == 'ENCODEC':
93
+ codes_path = self.datadir.parent / 'audios-encodec' / (os.path.splitext(_name)[0] + ".npy") # noqa
94
+ quantization = np.load(codes_path).squeeze(0).T # ..., 8
95
+ elif self.hp.vocoder_type == 'SPEECHTOKENIZER':
96
+ codes_path = self.datadir.parent / 'audios-speech-tokenizer/acoustic' / (os.path.splitext(_name)[0] + ".npy") # noqa
97
+ quantization = np.load(codes_path).T # ..., 7
98
+ else:
99
+ raise ValueError(f"Unknown vocoder_type {self.hp.vocoder_type}")
100
+
101
+ return quantization
102
+
103
+ def __getitem__(self, i):
104
+ dataname = self.dataset[i]
105
+ _name = self.datasetbase[i]
106
+ metadata = self.text[_name]
107
+
108
+ # Speaker 1
109
+ acoustic_tokens = self.load_quantization(_name)
110
+ acoustic_tokens = np.pad(
111
+ acoustic_tokens, [[1, 0],[0,0]], constant_values=c.SPKR_1)
112
+
113
+ npy_path = self.datadir.parent / 'audios-speech-tokenizer/semantic' / (os.path.splitext(_name)[0] + ".npy") # noqa
114
+ semantic_tokens = np.load(npy_path)[None]
115
+ semantic_tokens = np.pad(
116
+ semantic_tokens,[[0,0], [1, 0]], constant_values=c.SPKR_1)
117
+
118
+ if "name_2" in metadata:
119
+ wav, _ = sf.read(dataname.split(".")[0] + "_1.wav")
120
+ else:
121
+ wav, _ = sf.read(dataname)
122
+ audio = normalize(wav) * 0.95
123
+ speaker_embedding = self.spkr_embedding(
124
+ {"waveform": torch.FloatTensor(audio).unsqueeze(0),
125
+ "sample_rate": self.hp.sample_rate,}
126
+ ).reshape(1, -1)
127
+ speaker_embedding = np.repeat(
128
+ speaker_embedding, semantic_tokens.shape[1], axis=0)
129
+
130
+ # Speaker 2
131
+ if "text_2" in metadata:
132
+ _name = _name.split(".wav")[0] + "_2.wav"
133
+ acoustic_tokens_2 = self.load_quantization(_name)
134
+ acoustic_tokens_2 = np.pad(
135
+ acoustic_tokens_2, [[1, 0],[0,0]], constant_values=c.SPKR_2)
136
+
137
+ npy_path = self.datadir.parent / 'audios-speech-tokenizer/semantic' / (os.path.splitext(_name)[0] + ".npy") # noqa
138
+ semantic_tokens_2 = np.load(npy_path)[None]
139
+ semantic_tokens_2 = np.pad(
140
+ semantic_tokens_2,[[0,0], [1, 0]], constant_values=c.SPKR_2)
141
+
142
+ wav, _ = sf.read(dataname.split(".wav")[0] + "_2.wav")
143
+ audio = normalize(wav) * 0.95
144
+ speaker_embedding_2 = self.spkr_embedding(
145
+ {"waveform": torch.FloatTensor(audio).unsqueeze(0),
146
+ "sample_rate": self.hp.sample_rate,}
147
+ ).reshape(1, -1)
148
+ speaker_embedding_2 = np.repeat(
149
+ speaker_embedding_2, semantic_tokens_2.shape[1], axis=0)
150
+
151
+ # Merge both speakers
152
+ acoustic_tokens = np.concatenate(
153
+ (acoustic_tokens, acoustic_tokens_2), axis=0)
154
+ semantic_tokens = np.concatenate(
155
+ (semantic_tokens, semantic_tokens_2), axis=1)
156
+ speaker_embedding = np.concatenate(
157
+ (speaker_embedding, speaker_embedding_2), axis=0)
158
+
159
+ speaker_embedding = speaker_embedding[:self.hp.max_length, :]
160
+ acoustic_tokens = acoustic_tokens[:self.hp.max_length, :]
161
+ semantic_tokens = semantic_tokens[:, :self.hp.max_length]
162
+
163
+ # # HACK - we have no 8 lvls pfb30
164
+ # acoustic_tokens = np.concatenate((semantic_tokens.T, acoustic_tokens), axis=1)
165
+ # # END HACK
166
+
167
+ return speaker_embedding, acoustic_tokens, acoustic_tokens, dataname, semantic_tokens # noqa
modules/__init__.py ADDED
File without changes
modules/conformer.py ADDED
@@ -0,0 +1,671 @@
1
+ """Conformer definition adjusted given the Lucidrain's repo.
2
+ https://github.com/lucidrains/soundstorm-pytorch/blob/main/soundstorm_pytorch/soundstorm.py # noqa
3
+
4
+ Copyright PolyAI Limited.
5
+ """
6
+ import math  # used by T5RelativePositionBias below
+ from collections import namedtuple
7
+ from functools import wraps
8
+ from typing import Dict, Union
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from einops import rearrange, reduce
13
+ from einops.layers.torch import EinMix, Rearrange
14
+ from torch import einsum, nn
15
+
16
+
17
+ # rotary embedding
18
+ class RotaryEmbedding(nn.Module):
19
+ def __init__(self, dim, theta = 10000):
20
+ super().__init__()
21
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
22
+ self.register_buffer("inv_freq", inv_freq, persistent = False)
23
+
24
+ @property
25
+ def device(self):
26
+ return next(self.buffers()).device
27
+
28
+ def forward(self, seq_len):
29
+ t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)
30
+ freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
31
+ freqs = torch.cat((freqs, freqs), dim = -1)
32
+ return freqs
33
+
34
+ def rotate_half(x):
35
+ x1, x2 = x.chunk(2, dim=-1)
36
+ return torch.cat((-x2, x1), dim=-1)
37
+
38
+ def apply_rotary_pos_emb(pos, t):
39
+ return (t * pos.cos()) + (rotate_half(t) * pos.sin())
40
+
41
+
42
+ # constants
43
+ EfficientAttentionConfig = namedtuple(
44
+ 'EfficientAttentionConfig',
45
+ ['enable_flash', 'enable_math', 'enable_mem_efficient']
46
+ )
47
+
48
+ # helpers
49
+ def exists(val):
50
+ return val is not None
51
+
52
+ def default(val, d):
53
+ return val if exists(val) else d
54
+
55
+ def divisible_by(numer, denom):
56
+ return (numer % denom) == 0
57
+
58
+ def calc_same_padding(kernel_size):
59
+ pad = kernel_size // 2
60
+ return (pad, pad - (kernel_size + 1) % 2)
61
+
62
+ def eval_decorator(fn):
63
+ @wraps(fn)
64
+ def inner(model, *args, **kwargs):
65
+ was_training = model.training
66
+ model.eval()
67
+ out = fn(model, *args, **kwargs)
68
+ model.train(was_training)
69
+ return out
70
+ return inner
71
+
72
+
73
+ def once(fn):
74
+ called = False
75
+ @wraps(fn)
76
+ def inner(x):
77
+ nonlocal called
78
+ if called:
79
+ return
80
+ called = True
81
+ return fn(x)
82
+ return inner
83
+
84
+ print_once = once(print)
85
+
86
+
87
+ # t5 relative positional bias
88
+ class T5RelativePositionBias(nn.Module):
89
+ def __init__(
90
+ self,
91
+ scale = 1.,
92
+ num_buckets = 32,
93
+ max_distance = 128,
94
+ heads = 8
95
+ ):
96
+ super().__init__()
97
+ self.scale = scale
98
+ self.num_buckets = num_buckets
99
+ self.max_distance = max_distance
100
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
101
+
102
+ @staticmethod
103
+ def _relative_position_bucket(
104
+ relative_position,
105
+ num_buckets = 32,
106
+ max_distance = 128
107
+ ):
108
+ ret = 0
109
+ n = -relative_position
110
+
111
+ num_buckets //= 2
112
+ ret += (n < 0).long() * num_buckets
113
+ n = torch.abs(n)
114
+
115
+ max_exact = num_buckets // 2
116
+ is_small = n < max_exact
117
+
118
+ val_if_large = max_exact + (
119
+ torch.log(n.float() / max_exact) / math.log(
120
+ max_distance / max_exact) * (num_buckets - max_exact)
121
+ ).long()
122
+
123
+ val_if_large = torch.min(
124
+ val_if_large,
125
+ torch.full_like(val_if_large, num_buckets - 1)
126
+ )
127
+
128
+ ret += torch.where(is_small, n, val_if_large)
129
+ return ret
130
+
131
+ @property
132
+ def device(self):
133
+ return next(self.parameters()).device
134
+
135
+ def forward(self, n):
136
+ pos = torch.arange(n, device = self.device).long()
137
+ rel_pos = rearrange(pos, 'j -> 1 j') - rearrange(pos, 'i -> i 1')
138
+
139
+ rp_bucket = self._relative_position_bucket(
140
+ rel_pos, num_buckets = self.num_buckets,
141
+ max_distance = self.max_distance)
142
+ values = self.relative_attention_bias(rp_bucket)
143
+
144
+ bias = rearrange(values, 'i j h -> h i j')
145
+ return bias * self.scale
146
+
147
+
148
+ # main class
149
+ class Attend(nn.Module):
150
+ def __init__(
151
+ self,
152
+ causal = False,
153
+ dropout = 0.,
154
+ flash = False
155
+ ):
156
+ super().__init__()
157
+ self.dropout = dropout
158
+ self.attn_dropout = nn.Dropout(dropout)
159
+
160
+ self.causal = causal
161
+ self.flash = flash
162
+
163
+ # determine efficient attention configs for cuda and cpu
164
+ self.cpu_config = EfficientAttentionConfig(True, True, True)
165
+ self.cuda_config = None
166
+
167
+ if not torch.cuda.is_available() or not flash:
168
+ return
169
+
170
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
171
+
172
+ if device_properties.major == 8 and device_properties.minor == 0:
173
+ print_once('A100 GPU detected, using flash attention if input tensor is on cuda') # noqa
174
+ self.cuda_config = EfficientAttentionConfig(True, True, True)
175
+ else:
176
+ print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda') # noqa
177
+ self.cuda_config = EfficientAttentionConfig(False, True, True)
178
+
179
+ def get_mask(self, i, j, device):
180
+ return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1) # noqa
181
+
182
+ def flash_attn(self, q, k, v, mask = None, attn_bias = None):
183
+ _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device # noqa
184
+
185
+ # single headed key / values
186
+
187
+ if k.ndim == 3:
188
+ k = rearrange(k, 'b n d -> b 1 n d')
189
+
190
+ if v.ndim == 3:
191
+ v = rearrange(v, 'b n d -> b 1 n d')
192
+
193
+ # Check if mask exists and expand to compatible shape
194
+ # The mask is B L, so it would have to be expanded to B H N L
195
+ if exists(mask) and mask.ndim != 4:
196
+ mask = rearrange(mask, 'b j -> b 1 1 j')
197
+ mask = mask.expand(-1, heads, q_len, -1)
198
+
199
+ # Check if there is a compatible device for flash attention
200
+ config = self.cuda_config if is_cuda else self.cpu_config
201
+ causal = self.causal
202
+
203
+ # handle attention bias
204
+ if exists(attn_bias):
205
+ mask_value = -torch.finfo(q.dtype).max // 2
206
+ causal_mask = self.get_mask(q_len, k_len, device)
207
+ attn_bias = attn_bias.masked_fill(causal_mask, mask_value)
208
+
209
+ if exists(mask):
210
+ attn_bias = attn_bias.masked_fill(~mask, mask_value)
211
+
212
+ mask = attn_bias
213
+ causal = False
214
+
215
+ # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
216
+ with torch.backends.cuda.sdp_kernel(**config._asdict()):
217
+ out = F.scaled_dot_product_attention(
218
+ q, k, v,
219
+ attn_mask = mask,
220
+ dropout_p = self.dropout if self.training else 0.,
221
+ is_causal = causal
222
+ )
223
+
224
+ return out
225
+
226
+ def forward(self, q, k, v, mask = None, attn_bias = None):
227
+ """
228
+ einstein notation
229
+ b - batch
230
+ h - heads
231
+ n, i, j - sequence length (base sequence length, source, target)
232
+ d - feature dimension
233
+ """
234
+
235
+ q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
236
+
237
+ scale = q.shape[-1] ** -0.5
238
+
239
+ kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
240
+
241
+ if self.flash:
242
+ assert not exists(attn_bias)
243
+ return self.flash_attn(q, k, v, mask = mask)
244
+
245
+ # similarity
246
+
247
+ sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
248
+
249
+ # attention bias
250
+
251
+ if exists(attn_bias):
252
+ sim = sim + attn_bias
253
+
254
+ # causal mask
255
+ if self.causal:
256
+ causal_mask = self.get_mask(q_len, k_len, device)
257
+ sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
258
+
259
+ # key padding mask
260
+ if exists(mask):
261
+ if mask.ndim != 4:
262
+ mask = rearrange(mask, 'b j -> b 1 1 j')
263
+ sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
264
+
265
+ # attention
266
+ attn = sim.softmax(dim=-1)
267
+ attn = self.attn_dropout(attn)
268
+
269
+ # aggregate values
270
+ out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
271
+
272
+ return out
273
+
274
+
275
+ class Swish(nn.Module):
276
+ def forward(self, x):
277
+ return x * x.sigmoid()
278
+
279
+
280
+ class GLU(nn.Module):
281
+ def __init__(self, dim):
282
+ super().__init__()
283
+ self.dim = dim
284
+
285
+ def forward(self, x):
286
+ out, gate = x.chunk(2, dim=self.dim)
287
+ return out * gate.sigmoid()
288
+
289
+
290
+ class DepthWiseConv1d(nn.Module):
291
+ def __init__(self, chan_in, chan_out, kernel_size, padding):
292
+ super().__init__()
293
+ self.padding = padding
294
+ self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)
295
+
296
+ def forward(self, x):
297
+ x = F.pad(x, self.padding)
298
+ return self.conv(x)
299
+
300
+
301
+ class Scale(nn.Module):
302
+ def __init__(self, scale, fn):
303
+ super().__init__()
304
+ self.fn = fn
305
+ self.scale = scale
306
+
307
+ def forward(self, x, **kwargs):
308
+ return self.fn(x, **kwargs) * self.scale
309
+
310
+
311
+ class ChanLayerNorm(nn.Module):
312
+ def __init__(self, dim):
313
+ super().__init__()
314
+ self.gamma = nn.Parameter(torch.ones(1, dim, 1))
315
+
316
+ def forward(self, x):
317
+ eps = 1e-6 if x.dtype == torch.float32 else 1e-4
318
+ var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
319
+ mean = torch.mean(x, dim = 1, keepdim = True)
320
+ return (x - mean) * var.clamp(min = eps).rsqrt() * self.gamma
321
+
322
+
323
+ class PreNorm(nn.Module):
324
+ def __init__(self, dim, fn):
325
+ super().__init__()
326
+ self.fn = fn
327
+ self.norm = nn.LayerNorm(dim)
328
+
329
+ def forward(self, x, **kwargs):
330
+ x = self.norm(x)
331
+ return self.fn(x, **kwargs)
332
+
333
+
334
+ class Attention(nn.Module):
335
+ def __init__(
336
+ self,
337
+ dim,
338
+ heads = 8,
339
+ dim_head = 64,
340
+ dropout = 0.,
341
+ flash = True
342
+ ):
343
+ super().__init__()
344
+ inner_dim = dim_head * heads
345
+ self.heads= heads
346
+ self.scale = dim_head ** -0.5
347
+
348
+ self.attend = Attend(
349
+ flash = flash,
350
+ dropout = dropout
351
+ )
352
+
353
+ self.dropout = nn.Dropout(dropout)
354
+
355
+ self.to_q = nn.Linear(dim, inner_dim, bias = False)
356
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
357
+ self.to_out = nn.Linear(inner_dim, dim)
358
+
359
+ def forward(
360
+ self,
361
+ x,
362
+ context = None,
363
+ mask = None,
364
+ rotary_emb = None,
365
+ attn_bias = None
366
+ ):
367
+ n, device, h, has_context = x.shape[-2], x.device, self.heads, exists(context)
368
+ context = default(context, x)
369
+
370
+ q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
371
+ q, k, v = map(
372
+ lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
373
+
374
+ if exists(rotary_emb):
375
+ q = apply_rotary_pos_emb(rotary_emb, q)
376
+ k = apply_rotary_pos_emb(rotary_emb, k)
377
+
378
+ out = self.attend(q, k, v, mask = mask, attn_bias = attn_bias)
379
+
380
+ out = rearrange(out, 'b h n d -> b n (h d)')
381
+ return self.to_out(out)
382
+
383
+
384
+ class FeedForward(nn.Module):
385
+ def __init__(
386
+ self,
387
+ dim,
388
+ mult = 4,
389
+ dropout = 0.
390
+ ):
391
+ super().__init__()
392
+ self.net = nn.Sequential(
393
+ nn.Linear(dim, dim * mult),
394
+ Swish(),
395
+ nn.Dropout(dropout),
396
+ nn.Linear(dim * mult, dim),
397
+ nn.Dropout(dropout)
398
+ )
399
+
400
+ def forward(self, x):
401
+ return self.net(x)
402
+
403
+
404
+ class ConformerConvModule(nn.Module):
405
+ def __init__(
406
+ self,
407
+ dim,
408
+ causal = False,
409
+ expansion_factor = 2,
410
+ kernel_size = 31,
411
+ dropout = 0.
412
+ ):
413
+ super().__init__()
414
+
415
+ inner_dim = dim * expansion_factor
416
+ padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)
417
+
418
+ self.net = nn.Sequential(
419
+ nn.LayerNorm(dim),
420
+ Rearrange('b n c -> b c n'),
421
+ nn.Conv1d(dim, inner_dim * 2, 1),
422
+ GLU(dim=1),
423
+ DepthWiseConv1d(
424
+ inner_dim, inner_dim, kernel_size = kernel_size,
425
+ padding = padding
426
+ ),
427
+ Swish(),
428
+ ChanLayerNorm(inner_dim),
429
+ nn.Conv1d(inner_dim, dim, 1),
430
+ Rearrange('b c n -> b n c'),
431
+ nn.Dropout(dropout)
432
+ )
433
+
434
+ def forward(self, x):
435
+ return self.net(x)
436
+
437
+
438
+ # Conformer Block
439
+ class ConformerBlock(nn.Module):
440
+ def __init__(
441
+ self,
442
+ *,
443
+ dim,
444
+ dim_head = 64,
445
+ heads = 8,
446
+ ff_mult = 4,
447
+ conv_expansion_factor = 2,
448
+ conv_kernel_size = 31,
449
+ attn_dropout = 0.,
450
+ attn_flash = True,
451
+ ff_dropout = 0.,
452
+ conv_dropout = 0.,
453
+ conv_causal = False
454
+ ):
455
+ super().__init__()
456
+ self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
457
+ self.attn = Attention(
458
+ dim = dim, dim_head = dim_head, heads = heads,
459
+ dropout = attn_dropout, flash = attn_flash
460
+ )
461
+ self.conv = ConformerConvModule(
462
+ dim = dim, causal = conv_causal,
463
+ expansion_factor = conv_expansion_factor,
464
+ kernel_size = conv_kernel_size, dropout = conv_dropout
465
+ )
466
+ self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
467
+
468
+ self.attn = PreNorm(dim, self.attn)
469
+ self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
470
+ self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))
471
+
472
+ self.post_norm = nn.LayerNorm(dim)
473
+
474
+ def forward(
475
+ self,
476
+ x,
477
+ mask = None,
478
+ rotary_emb = None,
479
+ attn_bias = None
480
+ ):
481
+ x = self.ff1(x) + x
482
+ x = self.attn(x, mask = mask, rotary_emb = rotary_emb, attn_bias = attn_bias) + x # noqa
483
+ x = self.conv(x) + x
484
+ x = self.ff2(x) + x
485
+ x = self.post_norm(x)
486
+ return x
487
+
488
+
489
+ # Conformer
490
+ class Conformer(nn.Module):
491
+ def __init__(
492
+ self,
493
+ dim,
494
+ *,
495
+ num_layers,
496
+ dim_head = 64,
497
+ heads = 8,
498
+ ff_mult = 4,
499
+ conv_expansion_factor = 2,
500
+ conv_kernel_size = 31,
501
+ attn_dropout = 0.,
502
+ ff_dropout = 0.,
503
+ conv_dropout = 0.,
504
+ conv_causal = False,
505
+ attn_flash = True,
506
+ t5_rel_pos_bias = False
507
+ ):
508
+ super().__init__()
509
+
510
+ assert not (t5_rel_pos_bias and attn_flash), 'flash attention is not compatible with learned bias' # noqa
511
+
512
+ self.dim = dim
513
+ self.layers = nn.ModuleList([])
514
+
515
+ self.rotary_emb = RotaryEmbedding(
516
+ dim_head) if not t5_rel_pos_bias else None
517
+ self.rel_pos_bias = T5RelativePositionBias(
518
+ dim_head ** 0.5, heads = heads) if t5_rel_pos_bias else None
519
+
520
+ for _ in range(num_layers):
521
+ self.layers.append(ConformerBlock(
522
+ dim = dim,
523
+ dim_head = dim_head,
524
+ heads = heads,
525
+ ff_mult = ff_mult,
526
+ conv_expansion_factor = conv_expansion_factor,
527
+ conv_kernel_size = conv_kernel_size,
528
+ attn_dropout = attn_dropout,
529
+ ff_dropout = ff_dropout,
530
+ conv_dropout = conv_dropout,
531
+ conv_causal = conv_causal,
532
+ attn_flash = attn_flash
533
+ ))
534
+
535
+ def forward(self, x, mask = None):
536
+ seq_len = x.shape[-2]
537
+
538
+ rotary_emb = self.rotary_emb(seq_len) if exists(self.rotary_emb) else None # noqa
539
+ attn_bias = self.rel_pos_bias(seq_len) if exists(self.rel_pos_bias) else None #noqa
540
+
541
+ for block in self.layers:
542
+ x = block(
543
+ x,
544
+ mask = mask,
545
+ rotary_emb = rotary_emb,
546
+ attn_bias = attn_bias
547
+ )
548
+ return x
549
+
550
+
551
+ # conformer with sum reduction across quantized tokens at the beginning,
552
+ # along with heads
553
+ class ConformerWrapper(nn.Module):
554
+ def __init__(
555
+ self,
556
+ *,
557
+ codebook_size,
558
+ num_quantizers,
559
+ conformer: Union[Conformer, Dict[str, any]],
560
+ grouped_quantizers = 1
561
+ ):
562
+ super().__init__()
563
+ self.conformer = conformer
564
+
565
+ if isinstance(conformer, dict):
566
+ self.conformer = Conformer(**self.conformer)
567
+
568
+ dim = self.conformer.dim
569
+
570
+ self.embedding_proj = nn.Sequential(
571
+ nn.Linear(dim * grouped_quantizers, dim),
572
+ nn.LayerNorm(dim)
573
+ ) if grouped_quantizers > 1 else nn.Identity()
574
+
575
+ num_codes_with_mask = codebook_size + 1
576
+ num_effective_quantizers = num_quantizers * grouped_quantizers
577
+
578
+ self.code_embeds = nn.Embedding(
579
+ num_codes_with_mask * num_effective_quantizers, dim)
580
+
581
+ self.register_buffer(
582
+ 'quantizer_offsets',
583
+ torch.arange(num_effective_quantizers) * num_codes_with_mask,
584
+ persistent = False
585
+ )
586
+ self.register_buffer(
587
+ 'mask_tokens', self.quantizer_offsets + num_codes_with_mask,
588
+ persistent = False
589
+ )
590
+
591
+ self.dim = dim
592
+ self.codebook_size = codebook_size
593
+
594
+ self.num_codes_with_mask = num_codes_with_mask
595
+ self.num_quantizers = num_quantizers
596
+ self.grouped_quantizers = grouped_quantizers
597
+
598
+ self.heads = nn.Sequential(
599
+ nn.Linear(dim, dim * num_effective_quantizers),
600
+ Rearrange('b n (h d) -> b (n h) d', h = num_effective_quantizers)
601
+ )
602
+
603
+ # each quantizer codebook would require its own logits weight
604
+ # and bias matrices
605
+ # the amazing einops makes this easy with 'EinMix'
606
+ self.to_logits = nn.Sequential(
607
+ nn.LayerNorm(dim),
608
+ Rearrange('b (n gq) d -> b n gq d', gq = num_effective_quantizers),
609
+ EinMix(
610
+ 'b n gq d -> b n gq l',
611
+ weight_shape = 'gq d l',
612
+ bias_shape = 'gq l',
613
+ gq = num_effective_quantizers,
614
+ l = codebook_size,
615
+ d = dim
616
+ ),
617
+ Rearrange('b ... d -> b (...) d')
618
+ )
619
+
620
+ def forward(
621
+ self,
622
+ x,
623
+ *,
624
+ mask = None,
625
+ cond = None,
626
+ sum_embeds = None,
627
+ return_embeddings = False,
628
+ return_logits_and_embeddings = False
629
+ ):
630
+ """
631
+ einops notation:
632
+ b - batch
633
+ n - sequence
634
+ g - groups
635
+ q - quantizers
636
+ d - feature dimension
637
+ """
638
+
639
+ n, q, g = x.shape[-1], self.num_quantizers, self.grouped_quantizers
640
+ assert divisible_by(n, g * q), 'sequence must be divisible by number of quantizers' # noqa
641
+
642
+ x = rearrange(x, 'b (n gq) -> b n gq', gq = g * q)
643
+ x = x + self.quantizer_offsets
644
+
645
+ x = self.code_embeds(x)
646
+
647
+ x = reduce(x, 'b n (g q) d -> b n (g d)', 'sum', g = g)
648
+
649
+ x = self.embedding_proj(x)
650
+
651
+ if exists(sum_embeds):
652
+ x = x + sum_embeds
653
+
654
+ if exists(cond):
655
+ if cond.ndim == 2:
656
+ cond = rearrange(cond, 'b d -> b 1 d')
657
+
658
+ x = x + cond
659
+
660
+ x = self.conformer(x, mask = mask)
661
+ embeds = self.heads(x)
662
+
663
+ if return_embeddings or not exists(self.to_logits):
664
+ return embeds
665
+
666
+ logits = self.to_logits(embeds)
667
+
668
+ if return_logits_and_embeddings:
669
+ return logits, embeds
670
+
671
+ return logits
modules/masking_logic.py ADDED
@@ -0,0 +1,111 @@
1
+ """Masking and sampling logic adapted from MaskGIT original paper:
2
+ https://github.com/google-research/maskgit
3
+
4
+ Copyright PolyAI Limited.
5
+ """
6
+ from dataclasses import dataclass
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+
13
+ @dataclass
14
+ class State:
15
+ """Holds decoding state data."""
16
+ # The position of the decoding loop in the length dimension.
17
+ cur_index: None
18
+ # The active sequence log probabilities and finished sequence scores.
19
+ cur_seqs: None
20
+ final_seqs: None
21
+
22
+
23
+ def state_init(init_indices, num_iter, start_iter=0):
24
+ """Initializes the decoding state data structure."""
25
+ cur_index_0 = start_iter
26
+ cur_seqs_0 = init_indices
27
+ final_seqs_0 = torch.unsqueeze(init_indices, 1)
28
+ final_seqs_0 = torch.tile(final_seqs_0, (1, num_iter, 1))
29
+ return State(
30
+ cur_index=cur_index_0, cur_seqs=cur_seqs_0, final_seqs=final_seqs_0)
31
+
32
+
33
+ def schedule(ratio, method="cosine"):
34
+ if method == "uniform":
35
+ mask_ratio = 1. - ratio
36
+ elif "pow" in method:
37
+ exponent = float(method.replace("pow", ""))
38
+ mask_ratio = 1. - ratio**exponent
39
+ elif method == "cosine":
40
+ mask_ratio = np.cos(ratio * (np.pi/2))
41
+
42
+ mask_ratio = np.clip(mask_ratio, 1e-6, 1.)
43
+ return mask_ratio
44
+
45
+
46
+ def mask_by_random_topk(mask_len, probs, temperature=1.0):
47
+ noise = gumbel_noise_like(probs)
48
+ confidence = torch.log(probs) + temperature * noise
49
+ sorted_confidence, _ = torch.sort(confidence, dim=-1)
50
+ # Obtains cut off threshold given the mask lengths.
51
+ cut_off = torch.take_along_dim(sorted_confidence, mask_len.long(), dim=-1)
52
+ # Masks tokens with lower confidence.
53
+ masking = (confidence < cut_off)
54
+ return masking
55
+
56
+
57
+ def gumbel_noise_like(t):
58
+ noise = torch.zeros_like(t).uniform_(1e-20, 1)
59
+ return -torch.log(-torch.log(noise))
60
+
61
+
62
+ def sample_from_logits(
63
+ logits,
64
+ sample: bool = True,
65
+ temperature: float = 1.0,
66
+ top_k: int = None,
67
+ top_p: float = None,
68
+ return_probs: bool = False
69
+ ):
70
+ shp = logits.shape[:-1]
71
+
72
+ # Apply top_k sampling
73
+ if top_k is not None:
74
+ v, _ = logits.topk(top_k)
75
+ logits[logits < v[..., [-1]]] = -float("inf")
76
+
77
+ # Apply top_p (nucleus) sampling
78
+ if top_p is not None and top_p < 1.0:
79
+ v, sorted_indices = logits.sort(descending=True)
80
+ cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1)
81
+
82
+ sorted_indices_to_remove = cumulative_probs > top_p
83
+ # Right shift indices_to_remove to keep 1st token over threshold
84
+ sorted_indices_to_remove = F.pad(
85
+ sorted_indices_to_remove, (1, 0), value=False)[..., :-1]
86
+
87
+ # Compute indices_to_remove in unsorted array
88
+ indices_to_remove = sorted_indices_to_remove.scatter(
89
+ -1, sorted_indices, sorted_indices_to_remove
90
+ )
91
+
92
+ logits[indices_to_remove] = -float("inf")
93
+
94
+ # Perform multinomial sampling after normalizing logits
95
+ probs = (
96
+ F.softmax(logits / temperature, dim=-1)
97
+ if temperature > 0
98
+ else logits.softmax(dim=-1)
99
+ )
100
+ token = (
101
+ probs.view(-1, probs.size(-1)).multinomial(1).squeeze(1).view(*shp)
102
+ if sample
103
+ else logits.argmax(-1)
104
+ )
105
+
106
+ if return_probs:
107
+ token_probs = probs.take_along_dim(
108
+ token.unsqueeze(-1), dim=-1).squeeze(-1)
109
+ return token, token_probs
110
+ else:
111
+ return token
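A minimal sketch (assumed) of the sampling helper above, as used by the S2A decoding loop: draw one token per position from logits with temperature, top-k and nucleus filtering. The logits are cloned because the helper masks them in place:

import torch
from modules.masking_logic import sample_from_logits

logits = torch.randn(2, 5, 1024)  # (batch, positions, codebook size) - illustrative shapes
tokens, probs = sample_from_logits(
    logits.clone(), sample=True, temperature=1.0,
    top_k=16, top_p=0.8, return_probs=True)
print(tokens.shape, probs.shape)  # torch.Size([2, 5]) torch.Size([2, 5])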
modules/s2a_model.py ADDED
@@ -0,0 +1,563 @@
1
+ """A2S model definition.
2
+
3
+ Copyright PolyAI Limited.
4
+ """
5
+ from typing import Union
6
+
7
+ import pytorch_lightning as pl
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.optim as optim
12
+ from einops import rearrange
13
+
14
+ import constants as c
15
+ from modules import masking_logic
16
+ from modules.conformer import Conformer
17
+ from modules.masking_logic import (State, mask_by_random_topk,
18
+ sample_from_logits, state_init)
19
+ from utils import load_checkpoint
20
+
21
+
22
+ class Pheme(pl.LightningModule):
23
+ def __init__(self, hp):
24
+ super().__init__()
25
+ self.hp = hp
26
+ self.model = TTSConformer(hp)
27
+ self.cross_entropy = nn.CrossEntropyLoss(
28
+ label_smoothing=self.hp.label_smoothing,
29
+ ignore_index=self.hp.n_codes
30
+ )
31
+ if self.hp.pretrained_path:
32
+ self.load()
33
+ else:
34
+ self.apply(self.init_weights)
35
+
36
+ if self.hp.only_inference:
37
+ self.model.eval()
38
+
39
+ self.save_hyperparameters()
40
+
41
+ def load(self):
42
+ state_dict = load_checkpoint(self.hp.pretrained_path)
43
+ print(f"Parameters loaded from {self.hp.pretrained_path}")
44
+ self.load_state_dict(state_dict, strict=True)
45
+
46
+ def init_weights(self, module):
47
+ if isinstance(module, nn.Linear):
48
+ module.weight.data.normal_(mean=0.0, std=0.02)
49
+ if module.bias is not None:
50
+ module.bias.data.zero_()
51
+ if isinstance(module, nn.Embedding):
52
+ module.weight.data.normal_(mean=0.0, std=0.02)
53
+ module._fill_padding_idx_with_zero()
54
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
55
+ module.bias.data.zero_()
56
+ module.weight.data.fill_(1.0)
57
+ elif isinstance(module, nn.Conv1d):
58
+ module.weight.data.normal_(mean=0.0, std=0.02)
59
+ if module.bias is not None:
60
+ module.bias.data.zero_()
61
+
62
+ def configure_optimizers(self):
63
+ optimizer_adam = optim.AdamW(
64
+ self.parameters(), lr=self.hp.lr,
65
+ betas=(self.hp.adam_beta1, self.hp.adam_beta2))
66
+
67
+ # Learning rate scheduler
68
+ num_training_steps = self.hp.training_step
69
+ num_warmup_steps = self.hp.warmup_step
70
+ num_flat_steps = int(self.hp.optim_flat_percent * num_training_steps)
71
+
72
+ def lambda_lr(current_step: int):
73
+ if current_step < num_warmup_steps:
74
+ return float(current_step) / float(max(1, num_warmup_steps))
75
+ elif current_step < (num_warmup_steps + num_flat_steps):
76
+ return 1.0
77
+ return max(
78
+ 0.0,
79
+ float(num_training_steps - current_step)
80
+ / float(
81
+ max(1, num_training_steps - (num_warmup_steps + num_flat_steps)) # noqa
82
+ ),
83
+ )
84
+
85
+ scheduler_adam = {
86
+ "scheduler": optim.lr_scheduler.LambdaLR(
87
+ optimizer_adam, lambda_lr),
88
+ "interval": "step",
89
+ }
90
+ return [optimizer_adam], [scheduler_adam]
91
+
92
+ def top_k_accuracy(self, y_true, y_pred_probabilities, k):
93
+ _, sorted_indices = torch.sort(y_pred_probabilities, descending=True)
94
+
95
+ # Get the top-k predictions
96
+ top_k_indices = sorted_indices[:, :k]
97
+ expanded_y_true = y_true.unsqueeze(1).expand_as(top_k_indices)
98
+
99
+ # Check if true labels exist in top-k predictions
100
+ hits = torch.sum(torch.eq(top_k_indices, expanded_y_true))
101
+ accuracy = hits.item() / (len(y_true) + 1e-7)
102
+
103
+ return accuracy
104
+
105
+ def training_step(self, batch, batch_idx):
106
+ # Sample training level
107
+ rvq_level = torch.randint(
108
+ 0, min(self.hp.first_n_lvls, self.hp.n_cluster_groups), (1,)).item()
109
+
110
+ target, chosen_tokens, _, _ = self.model(
111
+ batch["tts_quantize_input"], rvq_level, batch["semantic_tokens"],
112
+ batch["quantization_lengths"],
113
+ speaker_emb=batch["speaker"],
114
+ min_seq_length=batch["quantization_lengths"].min().item())
115
+
116
+ # Mask targets and labels
117
+ mask = chosen_tokens
118
+ target = target[mask]
119
+
120
+ labels = batch["tts_quantize_input"][:, :, rvq_level]
121
+ labels = labels[mask]
122
+
123
+ loss = self.cross_entropy(target, labels)
124
+ acc = (target.argmax(-1) == labels).float().mean()
125
+ self.log("train/loss", loss, on_step=True, prog_bar=True)
126
+ self.log("train/acc", acc, on_step=True, prog_bar=True)
127
+ self.log(
128
+ f"train/acc_lvl_{rvq_level}", acc, on_step=True, prog_bar=False)
129
+
130
+ return loss
131
+
132
+ def validation_step(self, batch, batch_idx, dataloader_idx=0):
133
+ speaker_emb = batch["speaker"]
134
+ acoustic_tokens = batch["tts_quantize_input"]
135
+ semantic_tokens = batch["semantic_tokens"]
136
+
137
+ if self.hp.only_inference:
138
+ self.inference(
139
+ acoustic_tokens, semantic_tokens, self.hp.first_n_lvls)
140
+ else:
141
+ rvq_level = torch.randint(
142
+ 0, min(self.hp.first_n_lvls, self.hp.n_cluster_groups), (1,)
143
+ ).item()
144
+
145
+ # FIXME: edge case
146
+ if len(semantic_tokens.shape) == 3:
147
+ semantic_tokens = rearrange(semantic_tokens, "B 1 T -> B T")
148
+
149
+ target, chosen_tokens, _, _ = self.model(
150
+ acoustic_tokens, rvq_level, semantic_tokens,
151
+ torch.tensor([acoustic_tokens.shape[1]]).to(self.device),
152
+ speaker_emb=speaker_emb,
153
+ min_seq_length=acoustic_tokens.shape[1]
154
+ )
155
+
156
+ target = target[chosen_tokens]
157
+ labels = acoustic_tokens[:, :, rvq_level][chosen_tokens]
158
+ loss = self.cross_entropy(target, labels)
159
+
160
+ acc = (target.argmax(-1) == labels).float().mean()
161
+ acc_5 = self.top_k_accuracy(labels, target, 5)
162
+
163
+ self.log(
164
+ f"val/dataset_{dataloader_idx}/loss",
165
+ loss,
166
+ on_epoch=True,
167
+ logger=True,
168
+ add_dataloader_idx=False,
169
+ )
170
+ self.log(
171
+ f"val/dataset_{dataloader_idx}/acc_lvl",
172
+ acc,
173
+ on_epoch=True,
174
+ logger=True,
175
+ add_dataloader_idx=False,
176
+ )
177
+ self.log(
178
+ f"val/dataset_{dataloader_idx}/acc_lvl_{rvq_level}",
179
+ acc,
180
+ on_epoch=True,
181
+ logger=True,
182
+ add_dataloader_idx=False,
183
+ )
184
+ self.log(
185
+ f"val/dataset_{dataloader_idx}/acc_top_5",
186
+ acc_5,
187
+ on_epoch=True,
188
+ logger=True,
189
+ add_dataloader_idx=False,
190
+ )
191
+ self.log(
192
+ f"val/dataset_{dataloader_idx}/acc_top_5_lvl_{rvq_level}",
193
+ acc_5,
194
+ on_epoch=True,
195
+ logger=True,
196
+ add_dataloader_idx=False,
197
+ )
198
+
199
+ def compute_stats(self, logits, labels, mask_ratio=0, rvq_level=0):
200
+ acc = (logits.argmax(-1) == labels).float().mean()
201
+ acc_5 = self.top_k_accuracy(labels, logits, 5)
202
+ acc_10 = self.top_k_accuracy(labels, logits, 10)
203
+
204
+ idx = torch.randperm(logits.shape[0])
205
+ logits_shuffled = logits[idx]
206
+ random = self.top_k_accuracy(labels, logits_shuffled, 10)
207
+ print(f"Mask ratio: {mask_ratio}, Level {rvq_level}: acc {acc},"
208
+ f"acc 5 {acc_5}, acc 10 {acc_10}, quasi random {random}")
209
+
210
+
211
+ class TTSConformer(pl.LightningModule):
212
+ def __init__(self, hp):
213
+ super().__init__()
214
+ self.hp = hp
215
+ self.padding_id = self.hp.n_codes
216
+
217
+ additional_codes = [c.PAD, c.SPKR_1, c.SPKR_2]
218
+
219
+ self.embedding = nn.ModuleList(
220
+ [
221
+ nn.Embedding(
222
+ self.hp.n_codes + len(additional_codes),
223
+ self.hp.hidden_size,
224
+ padding_idx=self.padding_id)
225
+ for _ in range(self.hp.n_cluster_groups)
226
+ ]
227
+ )
228
+
229
+ # Additional modules
230
+ self.semantic_embedding = nn.Embedding(
231
+ self.hp.n_semantic_codes + len(additional_codes),
232
+ self.hp.hidden_size,
233
+ padding_idx=self.padding_id)
234
+
235
+ if self.hp.use_spkr_emb:
236
+ self.spkr_linear = nn.Linear(c.SPKR_EMB_SIZE, self.hp.hidden_size)
237
+
238
+ self.conformer = Conformer(
239
+ dim=self.hp.hidden_size,
240
+ num_layers=self.hp.enc_nlayers,
241
+ heads=self.hp.nheads,
242
+ dim_head=64,
243
+ ff_mult=4,  # feed-forward dim = 4 * dim (hidden_size)
244
+ conv_expansion_factor=2,
245
+ conv_kernel_size=self.hp.depthwise_conv_kernel_size,
246
+ attn_dropout=self.hp.dropout,
247
+ ff_dropout=self.hp.dropout,
248
+ conv_dropout=self.hp.dropout,
249
+ attn_flash=True,
250
+ t5_rel_pos_bias=False
251
+ )
252
+
253
+ self.heads = nn.ModuleList(
254
+ [
255
+ nn.Linear(
256
+ self.hp.hidden_size,
257
+ self.hp.n_codes + len(additional_codes)
258
+ )
259
+ for _ in range(self.hp.n_cluster_groups)
260
+ ]
261
+ )
262
+
263
+ def build_mask_from_lengths(self, length, max_len=None):
264
+ max_len = max_len or length.max().item()
265
+ mask = torch.arange(
266
+ max_len, device=length.device)[None, :] >= length[:, None]
267
+ return mask.bool()
268
+
269
+ @torch.no_grad()
270
+ def create_mask(
271
+ self, B, T, lengths, mask_ratio=None, start_t=None,
272
+ min_seq_length=None
273
+ ):
274
+ # 1. Define the random length of condition tokens given the shortest
275
+ # audio in the batch
276
+ if start_t is None:
277
+ start_t = torch.randint(1, min_seq_length - 1, (1,)).item()
278
+
279
+ # 2. Mask the remaining tokens, sampling one masking level per batch
280
+ if mask_ratio is None:
281
+ ratio = torch.rand(1).item()
282
+ mask_ratio = masking_logic.schedule(ratio)
283
+
284
+ # Create a random tensor with values between 0 and 1
285
+ random_tensor = torch.rand(
286
+ (B, T - start_t), dtype=torch.float).to(self.device)
287
+ # Create a mask where values less than p are set to True
288
+ initial_mask = random_tensor < mask_ratio
289
+ length_mask = self.build_mask_from_lengths(
290
+ lengths - start_t, T - start_t)
291
+ # we can't pick up tokens past token lengths
292
+ initial_mask = torch.logical_and(initial_mask, ~length_mask)
293
+
294
+ # Constrain ratio to always include some samples
295
+ # If all are False let's pick up at least one:
296
+ if torch.sum(initial_mask) == 0:
297
+ choose_steps = torch.randint(low=0, high=(T - start_t), size=(B,))
298
+ initial_mask[torch.arange(B), choose_steps] = torch.tensor(
299
+ True, device=self.device)
300
+
301
+ # 3. Add condition tokens containing information
302
+ acoustic_token_mask = torch.cat(
303
+ (torch.full((B, start_t), False, device=self.device), initial_mask), # noqa
304
+ 1
305
+ )
306
+
307
+ return acoustic_token_mask, start_t, mask_ratio
308
+
309
+ def process_input(
310
+ self, data, lengths, rvq_level, min_seq_length=None,
311
+ mask_ratio=None, start_t=None, acoustic_token_mask=None
312
+ ):
313
+ """
314
+ data: (B, T, code_level, D)
315
+ rvq_level: int
316
+ """
317
+ B = data.size(0)
318
+ T = data.size(1)
319
+ level_data = data[:, :, rvq_level, :] # [B, T, C, D] -> [B, T, D]
320
+
321
+ # Choose acoustic tokens to mask
322
+ if acoustic_token_mask is None:
323
+ acoustic_token_mask, start_t, mask_ratio = self.create_mask(
324
+ B, T, lengths, mask_ratio=mask_ratio, start_t=start_t,
325
+ min_seq_length=min_seq_length)
326
+ # Remove code information from chosen tokens
327
+ level_data[acoustic_token_mask, :] = 0
328
+
329
+ # Embed only lower rvq_level
330
+ lower_code_data = data[:, :, :rvq_level, :].sum(dim=2)
331
+
332
+ # Combine with chosen tokens at rvq_level.
333
+ # Note: all tokens at rvq_level+1: will be discarded.
334
+ summed_data = torch.add(lower_code_data, level_data)
335
+
336
+ return summed_data, acoustic_token_mask, mask_ratio, start_t
337
+
338
+ def forward(
339
+ self, x, code_level, semantic_tokens, lengths,
340
+ speaker_emb=None, min_seq_length=10, mask_ratio=None, start_t=None,
341
+ acoustic_token_mask=None
342
+ ):
343
+ # FIXME: parallelize this
344
+ batch = []
345
+ for lvl, embed in enumerate(self.embedding[:(code_level + 1)]):
346
+ batch.append(embed(x[:, :, lvl])) # [B T D]
347
+
348
+ x = torch.stack(batch, dim=2) # [B T C D]
349
+ x, acoustic_token_mask, mask_ratio, start_t = self.process_input(
350
+ x, lengths, code_level, min_seq_length=min_seq_length,
351
+ mask_ratio=mask_ratio, start_t=start_t,
352
+ acoustic_token_mask=acoustic_token_mask
353
+ )
354
+
355
+ # Add phoneme embeddings
356
+ # Cross attention for all tokens?
357
+
358
+ # Add semantic tokens
359
+ # HACK ME
360
+ semantic_emb = self.semantic_embedding(semantic_tokens)
361
+ x = torch.add(x, semantic_emb)
362
+ # FIXME pfb30
363
+
364
+ # Merge different modalities
365
+ if self.hp.use_spkr_emb:
366
+ spkr_emb = F.normalize(speaker_emb, dim=-1)
367
+ spkr_emb = self.spkr_linear(
368
+ F.dropout(spkr_emb, self.hp.speaker_embed_dropout)
369
+ )
370
+ x = torch.add(x, spkr_emb)
371
+
372
+ output_frames = self.conformer(x, None)
373
+
374
+ x = self.heads[code_level](output_frames)
375
+
376
+ return x, acoustic_token_mask, mask_ratio, start_t
377
+
378
+ @torch.no_grad()
379
+ def inference(
380
+ self, codes, semantic_tokens,
381
+ length: torch.LongTensor, rvq_levels=7,
382
+ mask_ratio=0.99, maskgit_inference=True,
383
+ start_t: Union[torch.LongTensor, None] = None,
384
+ speaker_emb=None, steps=16
385
+ ):
386
+ # Use half of the recording for the conditioning
387
+ if start_t is None:
388
+ start_t = torch.tensor(int((codes.shape[1]) / 2)).long()
389
+
390
+ start_t = start_t.item()
391
+
392
+ for rvq_level in range(rvq_levels):
393
+ original_codes = torch.clone(codes)
394
+ if rvq_level == 0 and maskgit_inference:
395
+ codes = self.multi_step_inference(
396
+ original_codes, semantic_tokens, length,
397
+ start_t=start_t, vamp_filtering=False,
398
+ speaker_emb=speaker_emb, steps=16
399
+ )
400
+ else:
401
+ codes = self.one_step_inference(
402
+ original_codes, semantic_tokens, length,
403
+ code_level=rvq_level,
404
+ mask_ratio=mask_ratio, start_t=start_t,
405
+ speaker_emb=speaker_emb
406
+ )
407
+
408
+ codes = rearrange(codes, 'T C -> 1 T C')
409
+
410
+ # Remove any padding left
411
+ codes = rearrange(codes, '1 T C -> 1 C T')
412
+ codes = torch.where(codes >= self.hp.n_codes, 0, codes)
413
+ acoustic_tokens = codes
414
+ semantic_tokens = rearrange(semantic_tokens, 'b c -> b 1 c')
415
+ semantic_tokens = torch.where(
416
+ semantic_tokens >= self.hp.n_codes, 0, semantic_tokens)
417
+ codes = torch.cat([semantic_tokens, acoustic_tokens], dim=1)
418
+
419
+ return codes
420
+
421
+ @torch.no_grad()
422
+ def one_step_inference(
423
+ self, original_codes, semantic_tokens, lengths, code_level=0,
424
+ mask_ratio=0.99, start_t=0, inference_setup="argmax", speaker_emb=None
425
+ ):
426
+ codes = torch.clone(original_codes)
427
+ logits, _, _, _ = self.forward(
428
+ codes, code_level, semantic_tokens, lengths,
429
+ mask_ratio=mask_ratio, start_t=start_t,
430
+ speaker_emb=speaker_emb, acoustic_token_mask=False)
431
+
432
+ if inference_setup == "argmax":
433
+ probs = torch.nn.functional.softmax(logits, dim=-1)
434
+ top_indeces = torch.argmax(probs, dim=-1)
435
+
436
+ if inference_setup == "sampling":
437
+ top_indeces = torch.distributions.Categorical(
438
+ logits=logits).sample()
439
+
440
+ codes = rearrange(codes, '1 T C -> T C')
441
+ codes[start_t:, code_level] = top_indeces[0, start_t:]
442
+
443
+ return codes
444
+
445
+ @torch.no_grad()
446
+ def multi_step_inference(
447
+ self, original_codes, semantic_tokens, lengths,
448
+ start_t: torch.LongTensor=None,
449
+ choice_temperature=1.0, start_iter=0,
450
+ steps=16, vamp_filtering=False, speaker_emb=None
451
+ ):
452
+ codes = torch.clone(original_codes)
453
+ code_level = 0
454
+ _, seq_len, _ = original_codes.shape
455
+ mask_token_id = self.padding_id
456
+
457
+ # Get true codes for the prompt
458
+ prompt_mask = codes[:, :start_t, code_level]
459
+
460
+ # Fill up rest with masks
461
+ mask = torch.full(
462
+ (1, seq_len - start_t), mask_token_id, device=self.device)
463
+ inputs = torch.cat((prompt_mask, mask), 1)
464
+
465
+ num_mask_tokens_at_start = torch.sum(inputs == mask_token_id, axis=-1)
466
+
467
+ # Initializes state
468
+ state = state_init(inputs, steps, start_iter=start_iter)
469
+
470
+ def loop_cond_fn(state):
471
+ """Beam search loop termination condition."""
472
+ not_at_end = (state.cur_index < steps)
473
+ return not_at_end
474
+
475
+ while loop_cond_fn(state):
476
+ """Beam search loop state update function."""
477
+ step = state.cur_index
478
+ # Current input ids: [batch_size, seq_length].
479
+ cur_ids = state.cur_seqs
480
+
481
+ # Calls model on current seqs to get next-iteration seqs.
482
+ with torch.no_grad():
483
+ logits, _, _, _ = self.forward(
484
+ rearrange(inputs, 'B T -> B T 1'),
485
+ code_level,
486
+ semantic_tokens, lengths,
487
+ acoustic_token_mask=False,
488
+ speaker_emb=speaker_emb)
489
+
490
+ # Samples the ids using categorical sampling:
491
+ if vamp_filtering:
492
+ typical_mass = 0.2
493
+ typical_min_tokens = 1
494
+ top_p = None
495
+ sample_cutoff = 0.5
496
+ typical_filtering = False
497
+ # NOTE: pass only options accepted by masking_logic.sample_from_logits.
+ sampled_ids, selected_probs = sample_from_logits(
498
+ logits, sample=((step / steps) <= sample_cutoff),
499
+ temperature=choice_temperature,
500
+ top_k=None, top_p=top_p, return_probs=True,
504
+ )
505
+ else:
506
+ sampled_ids = torch.distributions.Categorical(
507
+ logits=logits).sample()
508
+
509
+ # Just updates the masked tokens.
510
+ unknown_map = (cur_ids == mask_token_id)
511
+ sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)
512
+ # Defines the mask ratio for the next round. The number to mask out
513
+ # is determined by mask_ratio * unknown_number_in_the_beginning.
514
+ ratio = 1. * (step + 1) / steps
515
+ mask_ratio = masking_logic.schedule(ratio)
516
+
517
+ # Updates final seqs with the current sampled_ids.
518
+ final_seqs = torch.clone(state.final_seqs)
519
+ final_seqs[:, step, :] = sampled_ids
520
+ # Computes the probabilities of each selected tokens.
521
+ probs = torch.nn.functional.softmax(logits, dim=-1)
522
+ # Extract the probabilities of sampled ids
523
+ selected_probs = torch.squeeze(
524
+ torch.take_along_dim(
525
+ probs, torch.unsqueeze(sampled_ids, -1), -1),
526
+ -1
527
+ )
528
+
529
+ # Ignores the tokens given in the input
530
+ # by overwriting their confidence.
531
+ selected_probs = torch.where(
532
+ unknown_map, selected_probs, torch.inf)
533
+ # Gets mask lens for each sample in the
534
+ # batch according to the mask ratio.
535
+ num_to_mask = torch.unsqueeze(
536
+ torch.floor(num_mask_tokens_at_start * mask_ratio), 1)
537
+
538
+ # Keeps at least one of prediction in this
539
+ # round and also masks out at least
540
+ # one and for the next iteration
541
+ num_to_mask = torch.maximum(
542
+ torch.tensor(1),
543
+ torch.minimum(
544
+ torch.sum(unknown_map, dim=-1, keepdim=True) - 1,
545
+ num_to_mask)
546
+ )
547
+ # Adds noise for randomness
548
+ masking = mask_by_random_topk(
549
+ num_to_mask, selected_probs, choice_temperature * (1. - ratio))
550
+ # Masks tokens with lower confidence.
551
+ sampled_ids = torch.where(masking, mask_token_id, sampled_ids)
552
+
553
+ state = State(
554
+ cur_index=state.cur_index + 1,
555
+ cur_seqs=sampled_ids,
556
+ final_seqs=final_seqs
557
+ )
558
+
559
+ codes = torch.clone(original_codes)
560
+ codes = rearrange(codes, '1 T C -> T C')
561
+ codes[:, 0] = state.final_seqs[0][-1]
562
+
563
+ return codes
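The multi-step branch above follows a MaskGIT-style loop: everything after the prompt starts masked, and each iteration re-masks the lowest-confidence tokens according to masking_logic.schedule. A standalone sketch of how the re-mask budget shrinks over the 16 iterations, assuming a cosine schedule (an assumption made here for illustration; the actual schedule lives in modules/masking_logic.py):

import math

def cosine_schedule(ratio: float) -> float:
    # Assumed shape of masking_logic.schedule: fraction of tokens left masked.
    return math.cos(ratio * math.pi / 2)

num_masked_at_start = 120   # tokens after the acoustic prompt (illustrative)
steps = 16
for step in range(steps):
    ratio = (step + 1) / steps
    num_to_mask = max(1, math.floor(num_masked_at_start * cosine_schedule(ratio)))
    print(step, num_to_mask)  # budget decays from ~119 down to 1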
modules/speech_tokenizer.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Speech tokenizer class.
2
+
3
+ Copyright PolyAI Limited.
4
+ """
5
+ import logging
6
+ import os
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torchaudio
11
+ from speechtokenizer import SpeechTokenizer as ST
12
+
13
+ from modules.tokenizer import BaseTokenizer
14
+
15
+
16
+ class SpeechTokenizer(BaseTokenizer):
17
+ def __init__(self, config_path: str, ckpt_path: str):
18
+ self.device = torch.device(
19
+ "cuda" if torch.cuda.is_available() else "cpu")
20
+ self.model = ST.load_from_checkpoint(
21
+ config_path, ckpt_path).to(self.device)
22
+ self.model.eval()
23
+
24
+ def encode_file(
25
+ self, folder_path: str, destination_folder: str, filename: str):
26
+ dest_path = os.path.join(
27
+ destination_folder, "semantic",
28
+ os.path.splitext(filename)[0] + ".npy"
29
+ )
30
+ dest_path2 = os.path.join(
31
+ destination_folder, "acoustic",
32
+ os.path.splitext(filename)[0] + ".npy"
33
+ )
34
+ if os.path.exists(dest_path) and os.path.exists(dest_path2):
35
+ pass
36
+ else:
37
+ self._create_subfolders(destination_folder=destination_folder)
38
+
39
+ file_path = os.path.join(folder_path, filename)
40
+ wav_info = torchaudio.info(file_path)
41
+ wav_dur_sec = wav_info.num_frames / wav_info.sample_rate
42
+ if wav_dur_sec > 60:
43
+ logging.info(
44
+ f"Skipping {file_path} is too long: {wav_dur_sec:.3f} sec,"
45
+ "can cause CUDA OOM"
46
+ )
47
+ return
48
+ wav, sr = torchaudio.load(file_path)
49
+ if sr != self.model.sample_rate:
50
+ logging.warning(
51
+ "Wav sample rate %(wav_sr)s does not match the model"
52
+ "sampling rate %(model_sr)s. Resampling audio",
53
+ {"wav_sr": sr, "model_sr": self.model.sample_rate},
54
+ )
55
+ wav = torchaudio.functional.resample(
56
+ wav, sr, self.model.sample_rate)
57
+ wav = wav.unsqueeze(0)
58
+ wav = wav.to(self.device)
59
+
60
+ # Extract discrete codes from SpeechTokenizer
61
+ with torch.no_grad():
62
+ codes = self.model.encode(wav) # codes: (n_q, B, T)
63
+
64
+ semantic_tokens = codes[0, 0, :]
65
+ acoustic_tokens = codes[1:, 0, :]
66
+
67
+ # Save the encoding as .npy
68
+ dest_path = os.path.join(
69
+ destination_folder, "acoustic",
70
+ os.path.splitext(filename)[0] + ".npy"
71
+ )
72
+ np.save(dest_path, acoustic_tokens.cpu().numpy())
73
+
74
+ dest_path = os.path.join(
75
+ destination_folder, "semantic",
76
+ os.path.splitext(filename)[0] + ".npy"
77
+ )
78
+ np.save(dest_path, semantic_tokens.cpu().numpy())
79
+
80
+ @staticmethod
81
+ def _create_subfolders(destination_folder: str):
82
+ if not os.path.exists(destination_folder + "/acoustic"):
83
+ os.makedirs(destination_folder + "/acoustic")
84
+
85
+ if not os.path.exists(destination_folder + "/semantic"):
86
+ os.makedirs(destination_folder + "/semantic")
modules/t2s_model.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """T2S model definition.
2
+
3
+ Copyright PolyAI Limited.
4
+ """
5
+ import os
6
+
7
+ import numpy as np
8
+ from torch import nn
9
+ from transformers import EvalPrediction, T5Config, T5ForConditionalGeneration
10
+
11
+ from data.collation import get_text_semantic_token_collater
12
+
13
+
14
+ def compute_custom_metrics(eval_prediction: EvalPrediction):
15
+ # eval_prediction: tuple
16
+ # eval_prediction[0]: tensor of decoder outputs(logits) (n_batch, n_semantic, n_tokens) # noqa
17
+ # eval_prediction[1]: tensor of encoder outputs (n_batch, n_text/n_phone, n_hidden) # noqa
18
+ logits = eval_prediction.predictions[0]
19
+ labels = eval_prediction.label_ids
20
+ n_vocab = logits.shape[-1]
21
+ mask = labels == -100
22
+ top_1 = np.argmax(logits, axis=-1) == labels
23
+ top_1[mask] = False
24
+ top_5 = np.argsort(logits, axis=-1)[:, :, -5:]
25
+ top_5 = np.any(top_5 == np.expand_dims(labels, axis=-1), axis=-1)
26
+ top_5[mask] = False
27
+
28
+ top_10 = np.argsort(logits, axis=-1)[:, :, -10:]
29
+ top_10 = np.any(top_10 == np.expand_dims(labels, axis=-1), axis=-1)
30
+ top_10[mask] = False
31
+
32
+ top_1_accuracy = np.sum(top_1) / np.sum(~mask)
33
+ top_5_accuracy = np.sum(top_5) / np.sum(~mask)
34
+ top_10_accuracy = np.sum(top_10) / np.sum(~mask)
35
+
36
+ return {
37
+ "top_1_accuracy": top_1_accuracy,
38
+ "top_5_accuracy": top_5_accuracy,
39
+ "top_10_accuracy": top_10_accuracy,
40
+ }
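A small self-contained check of compute_custom_metrics on random predictions (shapes are illustrative; run from the repository root so the modules package resolves):

import numpy as np
from transformers import EvalPrediction
from modules.t2s_model import compute_custom_metrics

logits = np.random.randn(2, 5, 100)               # (batch, seq, vocab)
labels = np.random.randint(0, 100, size=(2, 5))
labels[0, -1] = -100                               # -100 marks ignored positions
pred = EvalPrediction(predictions=(logits, None), label_ids=labels)
print(compute_custom_metrics(pred))                # top-1 / top-5 / top-10 accuracy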
41
+
42
+
43
+ class T2S(nn.Module):
44
+ def __init__(self, hp):
45
+ super().__init__()
46
+ self.text_tokens_file = "ckpt/unique_text_tokens.k2symbols"
47
+ self.collater = get_text_semantic_token_collater(self.text_tokens_file)
48
+ self.model_size = hp.model_size
49
+ self.vocab_size = len(self.collater.idx2token)
50
+ self.config = self._define_model_config(self.model_size)
51
+
52
+ print(f"{self.config = }")
53
+ self.t2s = T5ForConditionalGeneration(self.config)
54
+
55
+ def _define_model_config(self, model_size):
56
+ if model_size == "test":
57
+ # minimal configuration for quick tests
58
+ d_ff = 16
59
+ d_model = 8
60
+ d_kv = 32
61
+ num_heads = 1
62
+ num_decoder_layers = 1
63
+ num_layers = 1
64
+ elif model_size == "tiny":
65
+ # n_params = 16M
66
+ d_ff = 1024
67
+ d_model = 256
68
+ d_kv = 32
69
+ num_heads = 4
70
+ num_decoder_layers = 4
71
+ num_layers = 4
72
+ elif model_size == "t5small":
73
+ # n_params = 60M
74
+ d_ff = 2048
75
+ d_model = 512
76
+ d_kv = 64
77
+ num_heads = 8
78
+ num_decoder_layers = 6
79
+ num_layers = 6
80
+ elif model_size == "large":
81
+ # n_params = 100M
82
+ d_ff = 2048
83
+ d_model = 512
84
+ d_kv = 64
85
+ num_heads = 8
86
+ num_decoder_layers = 14
87
+ num_layers = 14
88
+ elif model_size == "Large":
89
+ # n_params = 114M
90
+ d_ff = 4096
91
+ d_model = 512
92
+ d_kv = 64
93
+ num_heads = 8
94
+ num_decoder_layers = 6
95
+ num_layers = 10
96
+ else:
97
+ raise ValueError(f"unknown {model_size}")
98
+
99
+ config = T5Config(
100
+ d_ff=d_ff,
101
+ d_model=d_model,
102
+ d_kv=d_kv,
103
+ num_heads=num_heads,
104
+ num_decoder_layers=num_decoder_layers,
105
+ num_layers=num_layers,
106
+ decoder_start_token_id=0,
107
+ eos_token_id=2,
108
+ vocab_size=self.vocab_size,
109
+ )
110
+
111
+ return config
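To sanity-check the size comments above, one can build the corresponding T5Config directly and count parameters; a sketch for the t5small branch (the vocab size below is a placeholder, the real one comes from the token collater):

from transformers import T5Config, T5ForConditionalGeneration

config = T5Config(
    d_ff=2048, d_model=512, d_kv=64, num_heads=8,
    num_decoder_layers=6, num_layers=6,
    decoder_start_token_id=0, eos_token_id=2,
    vocab_size=1000,                      # placeholder vocab size
)
model = T5ForConditionalGeneration(config)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")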
modules/tokenizer.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Base tokenizer class.
2
+
3
+ Copyright PolyAI Limited.
4
+ """
5
+ import os
6
+ from concurrent.futures import as_completed
7
+ from concurrent.futures import ThreadPoolExecutor
8
+
9
+ from tqdm import tqdm
10
+
11
+ from utils import measure_duration
12
+
13
+
14
+ class BaseTokenizer:
15
+ @measure_duration
16
+ def encode_files_with_model_seq(
17
+ self, folder_path: str, destination_folder: str):
18
+ # Ensure destination folder exists
19
+ if not os.path.exists(destination_folder):
20
+ os.makedirs(destination_folder)
21
+
22
+ # Go through each file in the folder
23
+ filenames = os.listdir(folder_path)
24
+ # encoding files has no side effects
25
+ for filename in tqdm(filenames):
26
+ self.encode_file(
27
+ folder_path=folder_path,
28
+ destination_folder=destination_folder,
29
+ filename=filename,
30
+ )
31
+
32
+ def get_chunk(self, folder_path, start_percent=0, end_percent=100):
33
+ filenames = os.listdir(folder_path)
34
+ total_files = len(filenames)
35
+
36
+ start_idx = int(total_files * (start_percent / 100))
37
+ end_idx = int(total_files * (end_percent / 100))
38
+
39
+ return filenames[start_idx:end_idx]
40
+
41
+ @measure_duration
42
+ def encode_files_with_model_concurrent(
43
+ self, folder_path: str, destination_folder: str, start_percent: int,
44
+ end_percent: int,
45
+ ):
46
+ # Ensure destination folder exists
47
+ if not os.path.exists(destination_folder):
48
+ os.makedirs(destination_folder)
49
+
50
+ # Go through each file in the folder
51
+ filenames = self.get_chunk(folder_path, start_percent, end_percent)
52
+
53
+ # encoding files has no side effects
54
+ with ThreadPoolExecutor(max_workers=40) as executor:
55
+ futures = [
56
+ executor.submit(
57
+ self.encode_file,
58
+ folder_path=folder_path,
59
+ destination_folder=destination_folder,
60
+ filename=filename,
61
+ )
62
+ for filename in filenames
63
+ ]
64
+ # Wait for all tasks to complete
65
+ for future in as_completed(futures):
66
+ future.result()
67
+
68
+ # Explicitly shut down the thread pool
69
+ executor.shutdown()
70
+
71
+ def encode_file(
72
+ self, folder_path: str, destination_folder: str, filename: str):
73
+ raise NotImplementedError
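A minimal subclass sketch showing the contract encode_file is expected to fulfil (the class name and the no-op body are illustrative):

import os
from modules.tokenizer import BaseTokenizer

class NoopTokenizer(BaseTokenizer):
    def encode_file(self, folder_path: str, destination_folder: str, filename: str):
        # A real tokenizer would load the audio here and write token arrays;
        # this stub only verifies that the input file exists.
        assert os.path.exists(os.path.join(folder_path, filename))

# NoopTokenizer().encode_files_with_model_seq("in_dir", "out_dir")  # hypothetical paths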
modules/vocoder.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Vocoder wrapper.
2
+
3
+ Copyright PolyAI Limited.
4
+ """
5
+ import enum
6
+
7
+ import numpy as np
8
+ import soundfile as sf
9
+ import torch
10
+ import torch.nn as nn
11
+ from speechtokenizer import SpeechTokenizer
12
+
13
+
14
+ class VocoderType(enum.Enum):
15
+ SPEECHTOKENIZER = ("SPEECHTOKENIZER", 320)
16
+
17
+ def __init__(self, name, compression_ratio):
18
+ self._name_ = name
19
+ self.compression_ratio = compression_ratio
20
+
21
+ def get_vocoder(self, ckpt_path, config_path, **kwargs):
22
+ if self.name == "SPEECHTOKENIZER":
23
+ if ckpt_path:
24
+ vocoder = STWrapper(ckpt_path, config_path)
25
+ else:
26
+ vocoder = STWrapper()
27
+ else:
28
+ raise ValueError(f"Unknown vocoder type {self.name}")
29
+ return vocoder
30
+
31
+
32
+ class STWrapper(nn.Module):
33
+ def __init__(
34
+ self,
35
+ ckpt_path: str = './ckpt/speechtokenizer/SpeechTokenizer.pt',
36
+ config_path = './ckpt/speechtokenizer/config.json',
37
+ ):
38
+ super().__init__()
39
+ self.model = SpeechTokenizer.load_from_checkpoint(
40
+ config_path, ckpt_path)
41
+
42
+ def eval(self):
43
+ self.model.eval()
44
+
45
+ @torch.no_grad()
46
+ def decode(self, codes: torch.Tensor, verbose: bool = False):
47
+ original_device = codes.device
48
+
49
+ codes = codes.to(self.device)
50
+ audio_array = self.model.decode(codes)
51
+
52
+ return audio_array.to(original_device)
53
+
54
+ def decode_to_file(self, codes_path, out_path) -> None:
55
+ codes = np.load(codes_path)
56
+ codes = torch.from_numpy(codes)
57
+ wav = self.decode(codes).cpu().numpy()
58
+ sf.write(out_path, wav, samplerate=self.model.sample_rate)
59
+
60
+ @torch.no_grad()
61
+ def encode(self, wav, verbose=False, n_quantizers: int = None):
62
+ original_device = wav.device
63
+ wav = wav.to(self.device)
64
+ codes = self.model.encode(wav) # codes: (n_q, B, T)
65
+ return codes.to(original_device)
66
+
67
+ def encode_to_file(self, wav_path, out_path) -> None:
68
+ wav, _ = sf.read(wav_path, dtype='float32')
69
+ wav = torch.from_numpy(wav).unsqueeze(0).unsqueeze(0)
70
+ codes = self.encode(wav).cpu().numpy()
71
+ np.save(out_path, codes)
72
+
73
+ def remove_weight_norm(self):
74
+ pass
75
+
76
+ @property
77
+ def device(self):
78
+ return next(self.model.parameters()).device
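A hedged round-trip sketch with the wrapper above (assumes the default SpeechTokenizer checkpoint paths exist locally; the input is random noise, used only to show shapes):

import torch
from modules.vocoder import VocoderType

vocoder = VocoderType["SPEECHTOKENIZER"].get_vocoder(None, None)
vocoder.eval()
wav = torch.randn(1, 1, 16000)        # one second of noise at 16 kHz, (B, 1, T)
codes = vocoder.encode(wav)           # (n_q, B, T / 320)
audio = vocoder.decode(codes)         # waveform tensor reconstructed from codes
print(codes.shape, audio.shape)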
79
+
transformer_infer.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Inference logic.
2
+
3
+ Copyright PolyAI Limited.
4
+ """
5
+ import argparse
6
+ import json
7
+ import logging
8
+ import os
9
+ import time
10
+ from pathlib import Path
11
+
12
+ import numpy as np
13
+ import soundfile as sf
14
+ import torch
15
+ from einops import rearrange
16
+ from librosa.util import normalize
17
+ from pyannote.audio import Inference
18
+ from transformers import GenerationConfig, T5ForConditionalGeneration
19
+
20
+ import constants as c
21
+ from data.collation import get_text_semantic_token_collater
22
+ from data.semantic_dataset import TextTokenizer
23
+ from modules.s2a_model import Pheme
24
+ from modules.vocoder import VocoderType
25
+
26
+ # Maximum allowed number of consecutive repetitions of a generated token
27
+ MAX_TOKEN_COUNT = 100
28
+
29
+ logging.basicConfig(level=logging.DEBUG)
30
+ device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
31
+
32
+
33
+ def parse_arguments():
34
+ parser = argparse.ArgumentParser()
35
+ parser.add_argument(
36
+ "--text", type=str,
37
+ default="I gotta say, I would never expect that to happen!"
38
+ )
39
+ parser.add_argument(
40
+ "--manifest_path", type=str, default="demo/manifest.json")
41
+ parser.add_argument("--outputdir", type=str, default="demo/")
42
+ parser.add_argument("--featuredir", type=str, default="demo/")
43
+ parser.add_argument(
44
+ "--text_tokens_file", type=str,
45
+ default="ckpt/unique_text_tokens.k2symbols"
46
+ )
47
+ parser.add_argument("--t2s_path", type=str, default="ckpt/t2s/")
48
+ parser.add_argument(
49
+ "--a2s_path", type=str, default="ckpt/s2a/s2a.ckpt")
50
+
51
+ parser.add_argument("--target_sample_rate", type=int, default=16_000)
52
+
53
+ parser.add_argument("--temperature", type=float, default=0.7)
54
+ parser.add_argument("--top_k", type=int, default=210)
55
+ parser.add_argument("--voice", type=str, default="male_voice")
56
+
57
+ return parser.parse_args()
58
+
59
+
60
+ class PhemeClient():
61
+ def __init__(self, args):
62
+ self.args = args
63
+ self.outputdir = args.outputdir
64
+ self.target_sample_rate = args.target_sample_rate
65
+ self.featuredir = Path(args.featuredir).expanduser()
66
+ self.collater = get_text_semantic_token_collater(args.text_tokens_file)
67
+ self.phonemizer = TextTokenizer()
68
+
69
+ self.load_manifest(args.manifest_path)
70
+
71
+ # T2S model
72
+ self.t2s = T5ForConditionalGeneration.from_pretrained(args.t2s_path)
74
+ self.t2s.to(device)
75
+ self.t2s.eval()
76
+
77
+ # S2A model
78
+ self.s2a = Pheme.load_from_checkpoint(args.a2s_path)
79
+ self.s2a.to(device=device)
80
+ self.s2a.eval()
81
+
82
+ # Vocoder
83
+ vocoder = VocoderType["SPEECHTOKENIZER"].get_vocoder(None, None)
84
+ self.vocoder = vocoder.to(device)
85
+ self.vocoder.eval()
86
+
87
+ self.spkr_embedding = Inference(
88
+ "pyannote/embedding",
89
+ window="whole",
90
+ use_auth_token=os.environ["HUGGING_FACE_HUB_TOKEN"],
91
+ )
92
+
93
+ def load_manifest(self, input_path):
94
+ input_file = {}
95
+ with open(input_path, "rb") as f:
96
+ for line in f:
97
+ temp = json.loads(line)
98
+ input_file[temp["audio_filepath"].split(".wav")[0]] = temp
99
+ self.input_file = input_file
100
+
101
+ def lazy_decode(self, decoder_output, symbol_table):
102
+ semantic_tokens = map(lambda x: symbol_table[x], decoder_output)
103
+ semantic_tokens = [int(x) for x in semantic_tokens if x.isdigit()]
104
+
105
+ return np.array(semantic_tokens)
106
+
107
+ def infer_text(self, text, voice, sampling_config):
108
+ semantic_prompt = np.load(self.args.featuredir + "/audios-speech-tokenizer/semantic/" + f"{voice}.npy") # noqa
109
+ phones_seq = self.phonemizer(text)[0]
110
+ input_ids = self.collater([phones_seq])
111
+ input_ids = input_ids.type(torch.IntTensor).to(device)
112
+
113
+ labels = [str(lbl) for lbl in semantic_prompt]
114
+ labels = self.collater([labels])[:, :-1]
115
+ decoder_input_ids = labels.to(device).long()
116
+ logging.debug(f"decoder_input_ids: {decoder_input_ids}")
117
+
118
+ counts = 1E10
119
+ while (counts > MAX_TOKEN_COUNT):
120
+ output_ids = self.t2s.generate(
121
+ input_ids, decoder_input_ids=decoder_input_ids,
122
+ generation_config=sampling_config).sequences
123
+
124
+ # check repetitiveness
125
+ _, counts = torch.unique_consecutive(output_ids, return_counts=True)
126
+ counts = max(counts).item()
127
+
128
+ output_semantic = self.lazy_decode(
129
+ output_ids[0], self.collater.idx2token)
130
+
131
+ # remove the prompt
132
+ return output_semantic[len(semantic_prompt):].reshape(1, -1)
133
+
134
+ def _load_speaker_emb(self, element_id_prompt):
135
+ wav, _ = sf.read(self.featuredir / element_id_prompt)
136
+ audio = normalize(wav) * 0.95
137
+ speaker_emb = self.spkr_embedding(
138
+ {
139
+ "waveform": torch.FloatTensor(audio).unsqueeze(0),
140
+ "sample_rate": self.target_sample_rate
141
+ }
142
+ ).reshape(1, -1)
143
+
144
+ return speaker_emb
145
+
146
+ def _load_prompt(self, prompt_file_path):
147
+ element_id_prompt = Path(prompt_file_path).stem
148
+ acoustic_path_prompt = self.featuredir / "audios-speech-tokenizer/acoustic" / f"{element_id_prompt}.npy" # noqa
149
+ semantic_path_prompt = self.featuredir / "audios-speech-tokenizer/semantic" / f"{element_id_prompt}.npy" # noqa
150
+
151
+ acoustic_prompt = np.load(acoustic_path_prompt).squeeze().T
152
+ semantic_prompt = np.load(semantic_path_prompt)[None]
153
+
154
+ return acoustic_prompt, semantic_prompt
155
+
156
+ def infer_acoustic(self, output_semantic, prompt_file_path):
157
+ semantic_tokens = output_semantic.reshape(1, -1)
158
+ acoustic_tokens = np.full(
159
+ [semantic_tokens.shape[1], 7], fill_value=c.PAD)
160
+
161
+ acoustic_prompt, semantic_prompt = self._load_prompt(prompt_file_path) # noqa
162
+
163
+ # Prepend prompt
164
+ acoustic_tokens = np.concatenate(
165
+ [acoustic_prompt, acoustic_tokens], axis=0)
166
+ semantic_tokens = np.concatenate([
167
+ semantic_prompt, semantic_tokens], axis=1)
168
+
169
+ # Add speaker
170
+ acoustic_tokens = np.pad(
171
+ acoustic_tokens, [[1, 0], [0, 0]], constant_values=c.SPKR_1)
172
+ semantic_tokens = np.pad(
173
+ semantic_tokens, [[0,0], [1, 0]], constant_values=c.SPKR_1)
174
+
175
+ speaker_emb = None
176
+ if self.s2a.hp.use_spkr_emb:
177
+ speaker_emb = self._load_speaker_emb(prompt_file_path)
178
+ speaker_emb = np.repeat(
179
+ speaker_emb, semantic_tokens.shape[1], axis=0)
180
+ speaker_emb = torch.from_numpy(speaker_emb).to(device)
181
+ else:
182
+ speaker_emb = None
183
+
184
+ acoustic_tokens = torch.from_numpy(
185
+ acoustic_tokens).unsqueeze(0).to(device).long()
186
+ semantic_tokens = torch.from_numpy(semantic_tokens).to(device).long()
187
+ start_t = torch.tensor(
188
+ [acoustic_prompt.shape[0]], dtype=torch.long, device=device)
189
+ length = torch.tensor([
190
+ semantic_tokens.shape[1]], dtype=torch.long, device=device)
191
+
192
+ codes = self.s2a.model.inference(
193
+ acoustic_tokens,
194
+ semantic_tokens,
195
+ start_t=start_t,
196
+ length=length,
197
+ maskgit_inference=True,
198
+ speaker_emb=speaker_emb
199
+ )
200
+
201
+ # Remove the prompt
202
+ synth_codes = codes[:, :, start_t:]
203
+ synth_codes = rearrange(synth_codes, "b c t -> c b t")
204
+
205
+ return synth_codes
206
+
207
+ def generate_audio(self, text, voice, sampling_config, prompt_file_path):
208
+ start_time = time.time()
209
+ output_semantic = self.infer_text(
210
+ text, voice, sampling_config
211
+ )
212
+ logging.debug(f"semantic_tokens: {time.time() - start_time}")
213
+
214
+ start_time = time.time()
215
+ codes = self.infer_acoustic(output_semantic, prompt_file_path)
216
+ logging.debug(f"acoustic_tokens: {time.time() - start_time}")
217
+
218
+ start_time = time.time()
219
+ audio_array = self.vocoder.decode(codes)
220
+ audio_array = rearrange(audio_array, "1 1 T -> T").cpu().numpy()
221
+ logging.debug(f"vocoder time: {time.time() - start_time}")
222
+
223
+ return audio_array
224
+
225
+ @torch.no_grad()
226
+ def infer(
227
+ self, text, voice="male_voice", temperature=0.7,
228
+ top_k=210, max_new_tokens=750,
229
+ ):
230
+ sampling_config = GenerationConfig.from_pretrained(
231
+ self.args.t2s_path,
232
+ top_k=top_k,
233
+ num_beams=1,
234
+ do_sample=True,
235
+ temperature=temperature,
236
+ num_return_sequences=1,
237
+ max_new_tokens=max_new_tokens,
238
+ return_dict_in_generate=True,
239
+ output_scores=True
240
+ )
241
+
242
+ voice_data = self.input_file[voice]
243
+ prompt_file_path = voice_data["audio_prompt_filepath"]
244
+ text = voice_data["text"] + " " + text
245
+
246
+ audio_array = self.generate_audio(
247
+ text, voice, sampling_config, prompt_file_path)
248
+
249
+ return audio_array
250
+
251
+
252
+ if __name__ == "__main__":
253
+ args = parse_arguments()
254
+ args.outputdir = Path(args.outputdir).expanduser()
255
+ args.outputdir.mkdir(parents=True, exist_ok=True)
256
+ args.manifest_path = Path(args.manifest_path).expanduser()
257
+
258
+ client = PhemeClient(args)
259
+ audio_array = client.infer(args.text, voice=args.voice)
260
+ sf.write(os.path.join(
261
+ args.outputdir, f"{args.voice}.wav"), audio_array,
262
+ args.target_sample_rate
263
+ )
utils/__init__.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Copyright PolyAI Limited."""
2
+ import logging
3
+ import pdb
4
+ import sys
5
+ import traceback
6
+ from functools import wraps
7
+ from time import time
8
+ from typing import List
9
+
10
+ import torch
11
+
12
+ from .symbol_table import SymbolTable
13
+
14
+
15
+ def load_checkpoint(ckpt_path: str) -> dict:
16
+ """
17
+ Load a checkpoint state dict, dropping any vocoder parameters.
18
+ """
19
+ state_dict: dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
20
+ new_state_dict = dict()
21
+ for p_name in state_dict.keys():
22
+ if p_name.startswith("vocoder"):
23
+ continue
24
+
25
+ new_state_dict[p_name] = state_dict[p_name]
26
+
27
+ return new_state_dict
28
+
29
+
30
+ def breakpoint_on_error(fn):
31
+ """Creates a breakpoint on error
32
+
33
+ Use as a wrapper
34
+
35
+ Args:
36
+ fn: the function
37
+
38
+ Returns:
39
+ inner function
40
+ """
41
+
42
+ def inner(*args, **kwargs):
43
+ try:
44
+ return fn(*args, **kwargs)
45
+ except Exception:
46
+ """Standard python way of creating a breakpoint on error"""
47
+ extype, value, tb = sys.exc_info()
48
+ print(f"extype={extype},\nvalue={value}")
49
+ traceback.print_exc()
50
+ pdb.post_mortem(tb)
51
+
52
+ return inner
53
+
54
+
55
+ def measure_duration(f):
56
+ @wraps(f)
57
+ def wrap(*args, **kw):
58
+ ts = time()
59
+ result = f(*args, **kw)
60
+ te = time()
61
+ logging.debug("func:%r took: %2.4f sec" % (f.__name__, te - ts))
62
+ return result
63
+
64
+ return wrap
65
+
66
+
67
+ def split_metapath(in_paths: List[str]):
68
+ other_paths = []
69
+
70
+ for itm_path in in_paths:
71
+ other_paths.append(itm_path)
72
+
73
+ return other_paths
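A brief usage sketch for the timing helper above (the slow_op function is illustrative):

import logging
from utils import measure_duration

logging.basicConfig(level=logging.DEBUG)

@measure_duration
def slow_op(n: int) -> int:
    return sum(range(n))

slow_op(1_000_000)   # logs: func:'slow_op' took: ... sec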
utils/get_tokens_speech_tokenizer.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Get tokens using the SpeechTokenizer.
2
+
3
+ Apply SpeechTokenizer to extract acoustic and semantic tokens.
4
+ The tokens will be extracted to
5
+ encoding_output/acoustic and encoding_output/semantic.
6
+
7
+ python utils/get_tokens_speech_tokenizer.py \
8
+ --config_path ckpt/speechtokenizer/config.json \
9
+ --ckpt_path ckpt/speechtokenizer/SpeechTokenizer.pt \
10
+ --encoding_input datasets/example/audios \
11
+ --encoding_output datasets/example/audios-speech-tokenizer
12
+
13
+ Copyright PolyAI Limited.
14
+ """
15
+ import argparse
16
+ import pathlib
17
+
18
+ from modules.speech_tokenizer import SpeechTokenizer
19
+
20
+ MQTTS_ROOT_PATH = str(pathlib.Path(__file__).parent.resolve())
21
+
22
+ if __name__ == "__main__":
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument(
25
+ "--config_path",
26
+ type=str,
27
+ help="Path to the SpeechTokenizer config",
28
+ default=MQTTS_ROOT_PATH + "/ckpt/speechtokenizer/config.json",
29
+ )
30
+ parser.add_argument(
31
+ "--ckpt_path",
32
+ type=str,
33
+ help="Path to the SpeechTokenizer checkpoint",
34
+ default=MQTTS_ROOT_PATH + "/ckpt/speechtokenizer/SpeechTokenizer.pt",
35
+ )
36
+ parser.add_argument(
37
+ "--encoding_input",
38
+ type=str,
39
+ help="Path to the input folder for encoding",
40
+ default=MQTTS_ROOT_PATH + "/datasets/giga-training-data/audios",
41
+ )
42
+ parser.add_argument(
43
+ "--encoding_output",
44
+ type=str,
45
+ help="Path where to save the encoded tokens",
46
+ default="/tmp/encoding_output",
47
+ )
48
+ parser.add_argument(
49
+ "--start_percent",
50
+ type=int,
51
+ default=0,
52
+ )
53
+ parser.add_argument(
54
+ "--end_percent",
55
+ type=int,
56
+ default=100,
57
+ )
58
+
59
+ args = parser.parse_args()
60
+ print("Parsed args")
61
+ print(args)
62
+
63
+ tokenizer = SpeechTokenizer(
64
+ config_path=args.config_path,
65
+ ckpt_path=args.ckpt_path,
66
+ )
67
+ tokenizer.encode_files_with_model_concurrent(
68
+ folder_path=args.encoding_input, destination_folder=args.encoding_output,
69
+ start_percent=args.start_percent, end_percent=args.end_percent
70
+ )
utils/symbol_table.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+
15
+ Copyright PolyAI Limited.
16
+ """
17
+ from dataclasses import dataclass, field
18
+ from typing import Dict, Generic, List, Optional, TypeVar, Union
19
+
20
+ Symbol = TypeVar('Symbol')
21
+
22
+
23
+ # Disable __repr__ otherwise it could freeze e.g. Jupyter.
24
+ @dataclass(repr=False)
25
+ class SymbolTable(Generic[Symbol]):
26
+ '''SymbolTable that maps symbol IDs, found on the FSA arcs to
27
+ actual objects. These objects can be arbitrary Python objects
28
+ that can serve as keys in a dictionary (i.e. they need to be
29
+ hashable and immutable).
30
+
31
+ The SymbolTable can only be read to/written from disk if the
32
+ symbols are strings.
33
+ '''
34
+ _id2sym: Dict[int, Symbol] = field(default_factory=dict)
35
+ '''Map an integer to a symbol.
36
+ '''
37
+
38
+ _sym2id: Dict[Symbol, int] = field(default_factory=dict)
39
+ '''Map a symbol to an integer.
40
+ '''
41
+
42
+ _next_available_id: int = 1
43
+ '''A helper internal field that helps adding new symbols
44
+ to the table efficiently.
45
+ '''
46
+
47
+ eps: Symbol = '<eps>'
48
+ '''Null symbol, always mapped to index 0.
49
+ '''
50
+
51
+ def __post_init__(self):
52
+ for idx, sym in self._id2sym.items():
53
+ assert self._sym2id[sym] == idx
54
+ assert idx >= 0
55
+
56
+ for sym, idx in self._sym2id.items():
57
+ assert idx >= 0
58
+ assert self._id2sym[idx] == sym
59
+
60
+ if 0 not in self._id2sym:
61
+ self._id2sym[0] = self.eps
62
+ self._sym2id[self.eps] = 0
63
+ else:
64
+ assert self._id2sym[0] == self.eps
65
+ assert self._sym2id[self.eps] == 0
66
+
67
+ self._next_available_id = max(self._id2sym) + 1
68
+
69
+ @staticmethod
70
+ def from_str(s: str) -> 'SymbolTable':
71
+ '''Build a symbol table from a string.
72
+
73
+ The string consists of lines. Every line has two fields separated
74
+ by space(s), tab(s) or both. The first field is the symbol and the
75
+ second the integer id of the symbol.
76
+
77
+ Args:
78
+ s:
79
+ The input string with the format described above.
80
+ Returns:
81
+ An instance of :class:`SymbolTable`.
82
+ '''
83
+ id2sym: Dict[int, str] = dict()
84
+ sym2id: Dict[str, int] = dict()
85
+
86
+ for line in s.split('\n'):
87
+ fields = line.split()
88
+ if len(fields) == 0:
89
+ continue # skip empty lines
90
+ assert len(fields) == 2, \
91
+ f'Expect a line with 2 fields. Given: {len(fields)}'
92
+ sym, idx = fields[0], int(fields[1])
93
+ assert sym not in sym2id, f'Duplicated symbol {sym}'
94
+ assert idx not in id2sym, f'Duplicated id {idx}'
95
+ id2sym[idx] = sym
96
+ sym2id[sym] = idx
97
+
98
+ eps = id2sym.get(0, '<eps>')
99
+
100
+ return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps)
101
+
102
+ @staticmethod
103
+ def from_file(filename: str) -> 'SymbolTable':
104
+ '''Build a symbol table from file.
105
+
106
+ Every line in the symbol table file has two fields separated by
107
+ space(s), tab(s) or both. The following is an example file:
108
+
109
+ .. code-block::
110
+
111
+ <eps> 0
112
+ a 1
113
+ b 2
114
+ c 3
115
+
116
+ Args:
117
+ filename:
118
+ Name of the symbol table file. Its format is documented above.
119
+
120
+ Returns:
121
+ An instance of :class:`SymbolTable`.
122
+
123
+ '''
124
+ with open(filename, 'r', encoding='utf-8') as f:
125
+ return SymbolTable.from_str(f.read().strip())
126
+
127
+ def to_str(self) -> str:
128
+ '''
129
+ Returns:
130
+ Return a string representation of this object. You can pass
131
+ it to the method ``from_str`` to recreate an identical object.
132
+ '''
133
+ s = ''
134
+ for idx, symbol in sorted(self._id2sym.items()):
135
+ s += f'{symbol} {idx}\n'
136
+ return s
137
+
138
+ def to_file(self, filename: str):
139
+ '''Serialize the SymbolTable to a file.
140
+
141
+ Every line in the symbol table file has two fields separated by
142
+ space(s), tab(s) or both. The following is an example file:
143
+
144
+ .. code-block::
145
+
146
+ <eps> 0
147
+ a 1
148
+ b 2
149
+ c 3
150
+
151
+ Args:
152
+ filename:
153
+ Name of the symbol table file. Its format is documented above.
154
+ '''
155
+ with open(filename, 'w') as f:
156
+ for idx, symbol in sorted(self._id2sym.items()):
157
+ print(symbol, idx, file=f)
158
+
159
+ def add(self, symbol: Symbol, index: Optional[int] = None) -> int:
160
+ '''Add a new symbol to the SymbolTable.
161
+
162
+ Args:
163
+ symbol:
164
+ The symbol to be added.
165
+ index:
166
+ Optional int id to which the symbol should be assigned.
167
+ If it is not available, a ValueError will be raised.
168
+
169
+ Returns:
170
+ The int id to which the symbol has been assigned.
171
+ '''
172
+ # Already in the table? Return its ID.
173
+ if symbol in self._sym2id:
174
+ return self._sym2id[symbol]
175
+ # Specific ID not provided - use next available.
176
+ if index is None:
177
+ index = self._next_available_id
178
+ # Specific ID provided but not available.
179
+ if index in self._id2sym:
180
+ raise ValueError(f"Cannot assign id '{index}' to '{symbol}' - "
181
+ f"already occupied by {self._id2sym[index]}")
182
+ self._sym2id[symbol] = index
183
+ self._id2sym[index] = symbol
184
+
185
+ # Update next available ID if needed
186
+ if self._next_available_id <= index:
187
+ self._next_available_id = index + 1
188
+
189
+ return index
190
+
191
+ def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]:
192
+ '''Get a symbol for an id or get an id for a symbol
193
+
194
+ Args:
195
+ k:
196
+ If it is an id, it tries to find the symbol corresponding
197
+ to the id; if it is a symbol, it tries to find the id
198
+ corresponding to the symbol.
199
+
200
+ Returns:
201
+ An id or a symbol depending on the given `k`.
202
+ '''
203
+ if isinstance(k, int):
204
+ return self._id2sym[k]
205
+ else:
206
+ return self._sym2id[k]
207
+
208
+ def merge(self, other: 'SymbolTable') -> 'SymbolTable':
209
+ '''Create a union of two SymbolTables.
210
+ Raises an AssertionError if the same IDs are occupied by
211
+ different symbols.
212
+
213
+ Args:
214
+ other:
215
+ A symbol table to merge with ``self``.
216
+
217
+ Returns:
218
+ A new symbol table.
219
+ '''
220
+ self._check_compatible(other)
221
+
222
+ id2sym = {**self._id2sym, **other._id2sym}
223
+ sym2id = {**self._sym2id, **other._sym2id}
224
+
225
+ return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=self.eps)
226
+
227
+ def _check_compatible(self, other: 'SymbolTable') -> None:
228
+ # Epsilon compatibility
229
+ assert self.eps == other.eps, f'Mismatched epsilon symbol: ' \
230
+ f'{self.eps} != {other.eps}'
231
+ # IDs compatibility
232
+ common_ids = set(self._id2sym).intersection(other._id2sym)
233
+ for idx in common_ids:
234
+ assert self[idx] == other[idx], f'ID conflict for id: {idx}, ' \
235
+ f'self[idx] = "{self[idx]}", ' \
236
+ f'other[idx] = "{other[idx]}"'
237
+ # Symbols compatibility
238
+ common_symbols = set(self._sym2id).intersection(other._sym2id)
239
+ for sym in common_symbols:
240
+ assert self[sym] == other[sym], f'ID conflict for id: {sym}, ' \
241
+ f'self[sym] = "{self[sym]}", ' \
242
+ f'other[sym] = "{other[sym]}"'
243
+
244
+ def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]:
245
+ return self.get(item)
246
+
247
+ def __contains__(self, item: Union[int, Symbol]) -> bool:
248
+ if isinstance(item, int):
249
+ return item in self._id2sym
250
+ else:
251
+ return item in self._sym2id
252
+
253
+ def __len__(self) -> int:
254
+ return len(self._id2sym)
255
+
256
+ def __eq__(self, other: 'SymbolTable') -> bool:
257
+ if len(self) != len(other):
258
+ return False
259
+
260
+ for s in self.symbols:
261
+ if self[s] != other[s]:
262
+ return False
263
+
264
+ return True
265
+
266
+ @property
267
+ def ids(self) -> List[int]:
268
+ '''Returns a list of integer IDs corresponding to the symbols.
269
+ '''
270
+ ans = list(self._id2sym.keys())
271
+ ans.sort()
272
+ return ans
273
+
274
+ @property
275
+ def symbols(self) -> List[Symbol]:
276
+ '''Returns a list of symbols (e.g., strings) corresponding to
277
+ the integer IDs.
278
+ '''
279
+ ans = list(self._sym2id.keys())
280
+ ans.sort()
281
+ return ans
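A brief usage sketch for SymbolTable (symbol names are illustrative):

from utils.symbol_table import SymbolTable

table = SymbolTable.from_str("<eps> 0\na 1\nb 2")
assert table["a"] == 1 and table[2] == "b"
table.add("c")                                   # assigned the next free id, 3
merged = table.merge(SymbolTable.from_str("<eps> 0\nd 10"))
print(merged.symbols)                            # ['<eps>', 'a', 'b', 'c', 'd']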