litagin commited on
Commit
26be912
·
1 Parent(s): 49ec68e
This view is limited to 50 files because it contains too many changes.   See raw diff
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ venv/
2
+ __pycache__/
AR/__init__.py ADDED
File without changes
AR/data/__init__.py ADDED
File without changes
AR/data/bucket_sampler.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/bucketsampler.py
2
+ import itertools
3
+ import math
4
+ import random
5
+ from random import shuffle
6
+ from typing import Iterator
7
+ from typing import Optional
8
+ from typing import TypeVar
9
+
10
+ import torch
11
+ import torch.distributed as dist
12
+ from torch.utils.data import Dataset
13
+ from torch.utils.data import Sampler
14
+
15
+ __all__ = [
16
+ "DistributedBucketSampler",
17
+ ]
18
+
19
+ T_co = TypeVar("T_co", covariant=True)
20
+
21
+
22
+ class DistributedBucketSampler(Sampler[T_co]):
23
+ r"""
24
+ sort the dataset wrt. input length
25
+ divide samples into buckets
26
+ sort within buckets
27
+ divide buckets into batches
28
+ sort batches
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ dataset: Dataset,
34
+ num_replicas: Optional[int] = None,
35
+ rank: Optional[int] = None,
36
+ shuffle: bool = True,
37
+ seed: int = 0,
38
+ drop_last: bool = False,
39
+ batch_size: int = 32,
40
+ ) -> None:
41
+ if num_replicas is None:
42
+ if not dist.is_available():
43
+ raise RuntimeError("Requires distributed package to be available")
44
+ num_replicas = dist.get_world_size()
45
+ if rank is None:
46
+ if not dist.is_available():
47
+ raise RuntimeError("Requires distributed package to be available")
48
+ rank = dist.get_rank()
49
+ torch.cuda.set_device(rank)
50
+ if rank >= num_replicas or rank < 0:
51
+ raise ValueError(
52
+ "Invalid rank {}, rank should be in the interval"
53
+ " [0, {}]".format(rank, num_replicas - 1)
54
+ )
55
+ self.dataset = dataset
56
+ self.num_replicas = num_replicas
57
+ self.rank = rank
58
+ self.epoch = 0
59
+ self.drop_last = drop_last
60
+ # If the dataset length is evenly divisible by # of replicas, then there
61
+ # is no need to drop any data, since the dataset will be split equally.
62
+ if (
63
+ self.drop_last and len(self.dataset) % self.num_replicas != 0
64
+ ): # type: ignore[arg-type]
65
+ # Split to nearest available length that is evenly divisible.
66
+ # This is to ensure each rank receives the same amount of data when
67
+ # using this Sampler.
68
+ self.num_samples = math.ceil(
69
+ (len(self.dataset) - self.num_replicas)
70
+ / self.num_replicas # type: ignore[arg-type]
71
+ )
72
+ else:
73
+ self.num_samples = math.ceil(
74
+ len(self.dataset) / self.num_replicas
75
+ ) # type: ignore[arg-type]
76
+ self.total_size = self.num_samples * self.num_replicas
77
+ self.shuffle = shuffle
78
+ self.seed = seed
79
+ self.batch_size = batch_size
80
+ self.id_with_length = self._get_sample_lengths()
81
+ self.id_buckets = self.make_buckets(bucket_width=2.0)
82
+
83
+ def _get_sample_lengths(self):
84
+ id_with_lengths = []
85
+ for i in range(len(self.dataset)):
86
+ id_with_lengths.append((i, self.dataset.get_sample_length(i)))
87
+ id_with_lengths.sort(key=lambda x: x[1])
88
+ return id_with_lengths
89
+
90
+ def make_buckets(self, bucket_width: float = 2.0):
91
+ buckets = []
92
+ cur = []
93
+ max_sec = bucket_width
94
+ for id, sec in self.id_with_length:
95
+ if sec < max_sec:
96
+ cur.append(id)
97
+ else:
98
+ buckets.append(cur)
99
+ cur = [id]
100
+ max_sec += bucket_width
101
+ if len(cur) > 0:
102
+ buckets.append(cur)
103
+ return buckets
104
+
105
+ def __iter__(self) -> Iterator[T_co]:
106
+ if self.shuffle:
107
+ # deterministically shuffle based on epoch and seed
108
+ g = torch.Generator()
109
+ g.manual_seed(self.seed + self.epoch)
110
+ random.seed(self.epoch + self.seed)
111
+ shuffled_bucket = []
112
+ for buc in self.id_buckets:
113
+ buc_copy = buc.copy()
114
+ shuffle(buc_copy)
115
+ shuffled_bucket.append(buc_copy)
116
+ grouped_batch_size = self.batch_size * self.num_replicas
117
+ shuffled_bucket = list(itertools.chain(*shuffled_bucket))
118
+ n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
119
+ batches = [
120
+ shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
121
+ for b in range(n_batch)
122
+ ]
123
+ shuffle(batches)
124
+ indices = list(itertools.chain(*batches))
125
+ else:
126
+ # type: ignore[arg-type]
127
+ indices = list(range(len(self.dataset)))
128
+
129
+ if not self.drop_last:
130
+ # add extra samples to make it evenly divisible
131
+ padding_size = self.total_size - len(indices)
132
+ if padding_size <= len(indices):
133
+ indices += indices[:padding_size]
134
+ else:
135
+ indices += (indices * math.ceil(padding_size / len(indices)))[
136
+ :padding_size
137
+ ]
138
+ else:
139
+ # remove tail of data to make it evenly divisible.
140
+ indices = indices[: self.total_size]
141
+ assert len(indices) == self.total_size
142
+
143
+ # subsample
144
+ indices = indices[self.rank : self.total_size : self.num_replicas]
145
+ assert len(indices) == self.num_samples
146
+
147
+ return iter(indices)
148
+
149
+ def __len__(self) -> int:
150
+ return self.num_samples
151
+
152
+ def set_epoch(self, epoch: int) -> None:
153
+ r"""
154
+ Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
155
+ use a different random ordering for each epoch. Otherwise, the next iteration of this
156
+ sampler will yield the same ordering.
157
+
158
+ Args:
159
+ epoch (int): Epoch number.
160
+ """
161
+ self.epoch = epoch
AR/data/data_module.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/data_module.py
2
+ from pytorch_lightning import LightningDataModule
3
+ from AR.data.bucket_sampler import DistributedBucketSampler
4
+ from AR.data.dataset import Text2SemanticDataset
5
+ from torch.utils.data import DataLoader
6
+
7
+
8
+ class Text2SemanticDataModule(LightningDataModule):
9
+ def __init__(
10
+ self,
11
+ config,
12
+ train_semantic_path,
13
+ train_phoneme_path,
14
+ dev_semantic_path=None,
15
+ dev_phoneme_path=None,
16
+ ):
17
+ super().__init__()
18
+ self.config = config
19
+ self.train_semantic_path = train_semantic_path
20
+ self.train_phoneme_path = train_phoneme_path
21
+ self.dev_semantic_path = dev_semantic_path
22
+ self.dev_phoneme_path = dev_phoneme_path
23
+ self.num_workers = self.config["data"]["num_workers"]
24
+
25
+ def prepare_data(self):
26
+ pass
27
+
28
+ def setup(self, stage=None, output_logs=False):
29
+ self._train_dataset = Text2SemanticDataset(
30
+ phoneme_path=self.train_phoneme_path,
31
+ semantic_path=self.train_semantic_path,
32
+ max_sec=self.config["data"]["max_sec"],
33
+ pad_val=self.config["data"]["pad_val"],
34
+ )
35
+ self._dev_dataset = self._train_dataset
36
+ # self._dev_dataset = Text2SemanticDataset(
37
+ # phoneme_path=self.dev_phoneme_path,
38
+ # semantic_path=self.dev_semantic_path,
39
+ # max_sample=self.config['data']['max_eval_sample'],
40
+ # max_sec=self.config['data']['max_sec'],
41
+ # pad_val=self.config['data']['pad_val'])
42
+
43
+ def train_dataloader(self):
44
+ batch_size = self.config["train"]["batch_size"]
45
+ sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
46
+ return DataLoader(
47
+ self._train_dataset,
48
+ batch_size=batch_size,
49
+ sampler=sampler,
50
+ collate_fn=self._train_dataset.collate,
51
+ num_workers=self.num_workers,
52
+ persistent_workers=True,
53
+ prefetch_factor=16,
54
+ )
55
+
56
+ def val_dataloader(self):
57
+ return DataLoader(
58
+ self._dev_dataset,
59
+ batch_size=1,
60
+ shuffle=False,
61
+ collate_fn=self._train_dataset.collate,
62
+ num_workers=max(self.num_workers, 12),
63
+ persistent_workers=True,
64
+ prefetch_factor=16,
65
+ )
66
+
67
+ # 这个会使用到嘛?
68
+ def test_dataloader(self):
69
+ return DataLoader(
70
+ self._dev_dataset,
71
+ batch_size=1,
72
+ shuffle=False,
73
+ collate_fn=self._train_dataset.collate,
74
+ )
AR/data/dataset.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/t2s_dataset.py
2
+ import pdb
3
+ import sys
4
+
5
+ # sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
6
+ import traceback, os
7
+ from typing import Dict
8
+ from typing import List
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import torch, json
13
+ from torch.utils.data import DataLoader
14
+ from torch.utils.data import Dataset
15
+ from transformers import AutoTokenizer
16
+
17
+ from text import cleaned_text_to_sequence
18
+
19
+ # from config import exp_dir
20
+
21
+
22
+ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0):
23
+ seq = sequences[0]
24
+ ndim = seq.ndim
25
+ if axis < 0:
26
+ axis += ndim
27
+ dtype = seq.dtype
28
+ pad_value = dtype.type(pad_value)
29
+ seq_lengths = [seq.shape[axis] for seq in sequences]
30
+ max_length = np.max(seq_lengths)
31
+
32
+ padded_sequences = []
33
+ for seq, length in zip(sequences, seq_lengths):
34
+ padding = (
35
+ [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
36
+ )
37
+ padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
38
+ padded_sequences.append(padded_seq)
39
+ batch = np.stack(padded_sequences)
40
+ return batch
41
+
42
+
43
+ class Text2SemanticDataset(Dataset):
44
+ """dataset class for text tokens to semantic model training."""
45
+
46
+ def __init__(
47
+ self,
48
+ phoneme_path: str,
49
+ semantic_path: str,
50
+ max_sample: int = None,
51
+ max_sec: int = 100,
52
+ pad_val: int = 1024,
53
+ # min value of phoneme/sec
54
+ min_ps_ratio: int = 3,
55
+ # max value of phoneme/sec
56
+ max_ps_ratio: int = 25,
57
+ ) -> None:
58
+ super().__init__()
59
+
60
+ self.semantic_data = pd.read_csv(
61
+ semantic_path, delimiter="\t", encoding="utf-8"
62
+ )
63
+ # get dict
64
+ self.path2 = phoneme_path # "%s/2-name2text.txt"%exp_dir#phoneme_path
65
+ self.path3 = "%s/3-bert" % (
66
+ os.path.basename(phoneme_path)
67
+ ) # "%s/3-bert"%exp_dir#bert_dir
68
+ self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
69
+ assert os.path.exists(self.path2)
70
+ assert os.path.exists(self.path6)
71
+ self.phoneme_data = {}
72
+ with open(self.path2, "r", encoding="utf8") as f:
73
+ lines = f.read().strip("\n").split("\n")
74
+
75
+ for line in lines:
76
+ tmp = line.split("\t")
77
+ if len(tmp) != 4:
78
+ continue
79
+ self.phoneme_data[tmp[0]] = [tmp[1], tmp[2], tmp[3]]
80
+
81
+ # self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item()
82
+ # pad for semantic tokens
83
+ self.PAD: int = pad_val
84
+ # self.hz = 25
85
+ # with open("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert/configs/s2.json", "r") as f:data = f.read()
86
+ # data=json.loads(data)["model"]["semantic_frame_rate"]#50hz
87
+ # self.hz=int(data[:-2])#
88
+ self.hz = int(os.environ.get("hz", "25hz")[:-2])
89
+
90
+ # max seconds of semantic token
91
+ self.max_sec = max_sec
92
+ self.min_ps_ratio = min_ps_ratio
93
+ self.max_ps_ratio = max_ps_ratio
94
+
95
+ if max_sample is not None:
96
+ self.semantic_data = self.semantic_data[:max_sample]
97
+
98
+ # {idx: (semantic, phoneme)}
99
+ # semantic list, phoneme list
100
+ self.semantic_phoneme = []
101
+ self.item_names = []
102
+
103
+ self.inited = False
104
+
105
+ if not self.inited:
106
+ # 调用初始化函数
107
+ self.init_batch()
108
+ self.inited = True
109
+ del self.semantic_data
110
+ del self.phoneme_data
111
+ # self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large")
112
+ # self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large")
113
+
114
+ def init_batch(self):
115
+ semantic_data_len = len(self.semantic_data)
116
+ phoneme_data_len = len(self.phoneme_data.keys())
117
+ print("semantic_data_len:", semantic_data_len)
118
+ print("phoneme_data_len:", phoneme_data_len)
119
+ print(self.semantic_data)
120
+ idx = 0
121
+ num_not_in = 0
122
+ num_deleted_bigger = 0
123
+ num_deleted_ps = 0
124
+ for i in range(semantic_data_len):
125
+ # 先依次遍历
126
+ # get str
127
+ item_name = self.semantic_data.iloc[i,0]
128
+ # print(self.phoneme_data)
129
+ try:
130
+ phoneme, word2ph, text = self.phoneme_data[item_name]
131
+ except Exception:
132
+ traceback.print_exc()
133
+ # print(f"{item_name} not in self.phoneme_data !")
134
+ num_not_in += 1
135
+ continue
136
+
137
+ semantic_str = self.semantic_data.iloc[i,1]
138
+ # get token list
139
+ semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
140
+ # (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len
141
+ # 过滤掉太长的样本
142
+ if (
143
+ len(semantic_ids) > self.max_sec * self.hz
144
+ ): #########1###根据token个数推测总时长过滤时长60s(config里)#40*25=1k
145
+ num_deleted_bigger += 1
146
+ continue
147
+ # (T, ), 这个速度不会很慢,所以可以在一开始就处理,无需在 __getitem__ 里面单个处理####
148
+ phoneme = phoneme.split(" ")
149
+
150
+ try:
151
+ phoneme_ids = cleaned_text_to_sequence(phoneme)
152
+ except:
153
+ traceback.print_exc()
154
+ # print(f"{item_name} not in self.phoneme_data !")
155
+ num_not_in += 1
156
+ continue
157
+ # if len(phoneme_ids) >400:###########2:改为恒定限制为semantic/2.5就行
158
+ if (
159
+ len(phoneme_ids) > self.max_sec * self.hz / 2.5
160
+ ): ###########2:改为恒定限制为semantic/2.5就行
161
+ num_deleted_ps += 1
162
+ continue
163
+ # if len(semantic_ids) > 1000:###########3
164
+ # num_deleted_bigger += 1
165
+ # continue
166
+
167
+ ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
168
+
169
+ if (
170
+ ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio
171
+ ): ##########4#3~25#每秒多少个phone
172
+ num_deleted_ps += 1
173
+ # print(item_name)
174
+ continue
175
+
176
+ self.semantic_phoneme.append((semantic_ids, phoneme_ids))
177
+ idx += 1
178
+ self.item_names.append(item_name)
179
+
180
+ min_num = 100 # 20直接不补#30补了也不存ckpt
181
+ leng = len(self.semantic_phoneme)
182
+ if leng < min_num:
183
+ tmp1 = self.semantic_phoneme
184
+ tmp2 = self.item_names
185
+ self.semantic_phoneme = []
186
+ self.item_names = []
187
+ for _ in range(max(2, int(min_num / leng))):
188
+ self.semantic_phoneme += tmp1
189
+ self.item_names += tmp2
190
+ if num_not_in > 0:
191
+ print(f"there are {num_not_in} semantic datas not in phoneme datas")
192
+ if num_deleted_bigger > 0:
193
+ print(
194
+ f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds"
195
+ )
196
+ if num_deleted_ps > 0:
197
+ # 4702 for LibriTTS, LirbriTTS 是标注数据, 是否需要筛?=> 需要,有值为 100 的极端值
198
+ print(
199
+ f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
200
+ )
201
+ """
202
+ there are 31 semantic datas not in phoneme datas
203
+ deleted 34 audios who's duration are bigger than 54 seconds
204
+ deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3
205
+ dataset.__len__(): 366463
206
+
207
+ """
208
+ # 345410 for LibriTTS
209
+ print("dataset.__len__():", self.__len__())
210
+
211
+ def __get_item_names__(self) -> List[str]:
212
+ return self.item_names
213
+
214
+ def __len__(self) -> int:
215
+ return len(self.semantic_phoneme)
216
+
217
+ def __getitem__(self, idx: int) -> Dict:
218
+ semantic_ids, phoneme_ids = self.semantic_phoneme[idx]
219
+ item_name = self.item_names[idx]
220
+ phoneme_ids_len = len(phoneme_ids)
221
+ # semantic tokens target
222
+ semantic_ids_len = len(semantic_ids)
223
+
224
+ flag = 0
225
+ path_bert = "%s/%s.pt" % (self.path3, item_name)
226
+ if os.path.exists(path_bert) == True:
227
+ bert_feature = torch.load(path_bert, map_location="cpu")
228
+ else:
229
+ flag = 1
230
+ if flag == 1:
231
+ # bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32)
232
+ bert_feature = None
233
+ else:
234
+ assert bert_feature.shape[-1] == len(phoneme_ids)
235
+ return {
236
+ "idx": idx,
237
+ "phoneme_ids": phoneme_ids,
238
+ "phoneme_ids_len": phoneme_ids_len,
239
+ "semantic_ids": semantic_ids,
240
+ "semantic_ids_len": semantic_ids_len,
241
+ "bert_feature": bert_feature,
242
+ }
243
+
244
+ def get_sample_length(self, idx: int):
245
+ semantic_ids = self.semantic_phoneme[idx][0]
246
+ sec = 1.0 * len(semantic_ids) / self.hz
247
+ return sec
248
+
249
+ def collate(self, examples: List[Dict]) -> Dict:
250
+ sample_index: List[int] = []
251
+ phoneme_ids: List[torch.Tensor] = []
252
+ phoneme_ids_lens: List[int] = []
253
+ semantic_ids: List[torch.Tensor] = []
254
+ semantic_ids_lens: List[int] = []
255
+ # return
256
+
257
+ for item in examples:
258
+ sample_index.append(item["idx"])
259
+ phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
260
+ semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64))
261
+ phoneme_ids_lens.append(item["phoneme_ids_len"])
262
+ semantic_ids_lens.append(item["semantic_ids_len"])
263
+
264
+ # pad 0
265
+ phoneme_ids = batch_sequences(phoneme_ids)
266
+ semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD)
267
+
268
+ # # convert each batch to torch.tensor
269
+ phoneme_ids = torch.tensor(phoneme_ids)
270
+ semantic_ids = torch.tensor(semantic_ids)
271
+ phoneme_ids_lens = torch.tensor(phoneme_ids_lens)
272
+ semantic_ids_lens = torch.tensor(semantic_ids_lens)
273
+ bert_padded = torch.FloatTensor(len(examples), 1024, max(phoneme_ids_lens))
274
+ bert_padded.zero_()
275
+
276
+ for idx, item in enumerate(examples):
277
+ bert = item["bert_feature"]
278
+ if bert != None:
279
+ bert_padded[idx, :, : bert.shape[-1]] = bert
280
+
281
+ return {
282
+ # List[int]
283
+ "ids": sample_index,
284
+ # torch.Tensor (B, max_phoneme_length)
285
+ "phoneme_ids": phoneme_ids,
286
+ # torch.Tensor (B)
287
+ "phoneme_ids_len": phoneme_ids_lens,
288
+ # torch.Tensor (B, max_semantic_ids_length)
289
+ "semantic_ids": semantic_ids,
290
+ # torch.Tensor (B)
291
+ "semantic_ids_len": semantic_ids_lens,
292
+ # torch.Tensor (B, 1024, max_phoneme_length)
293
+ "bert_feature": bert_padded,
294
+ }
295
+
296
+
297
+ if __name__ == "__main__":
298
+ root_dir = "/data/docker/liujing04/gpt-vits/prepare/dump_mix/"
299
+ dataset = Text2SemanticDataset(
300
+ phoneme_path=root_dir + "phoneme_train.npy",
301
+ semantic_path=root_dir + "semantic_train.tsv",
302
+ )
303
+
304
+ batch_size = 12
305
+ dataloader = DataLoader(
306
+ dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False
307
+ )
308
+ for i, batch in enumerate(dataloader):
309
+ if i % 1000 == 0:
310
+ print(i)
311
+ # if i == 0:
312
+ # print('batch["ids"]:', batch["ids"])
313
+ # print('batch["phoneme_ids"]:', batch["phoneme_ids"],
314
+ # batch["phoneme_ids"].shape)
315
+ # print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"],
316
+ # batch["phoneme_ids_len"].shape)
317
+ # print('batch["semantic_ids"]:', batch["semantic_ids"],
318
+ # batch["semantic_ids"].shape)
319
+ # print('batch["semantic_ids_len"]:', batch["semantic_ids_len"],
320
+ # batch["semantic_ids_len"].shape)
AR/models/__init__.py ADDED
File without changes
AR/models/t2s_lightning_module.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
2
+ import os, sys
3
+
4
+ now_dir = os.getcwd()
5
+ sys.path.append(now_dir)
6
+ from typing import Dict
7
+
8
+ import torch
9
+ from pytorch_lightning import LightningModule
10
+ from AR.models.t2s_model import Text2SemanticDecoder
11
+ from AR.modules.lr_schedulers import WarmupCosineLRSchedule
12
+ from AR.modules.optim import ScaledAdam
13
+
14
+
15
+ class Text2SemanticLightningModule(LightningModule):
16
+ def __init__(self, config, output_dir, is_train=True):
17
+ super().__init__()
18
+ self.config = config
19
+ self.top_k = 3
20
+ self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
21
+ pretrained_s1 = config.get("pretrained_s1")
22
+ if pretrained_s1 and is_train:
23
+ # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
24
+ print(
25
+ self.load_state_dict(
26
+ torch.load(pretrained_s1, map_location="cpu")["weight"]
27
+ )
28
+ )
29
+ if is_train:
30
+ self.automatic_optimization = False
31
+ self.save_hyperparameters()
32
+ self.eval_dir = output_dir / "eval"
33
+ self.eval_dir.mkdir(parents=True, exist_ok=True)
34
+
35
+ def training_step(self, batch: Dict, batch_idx: int):
36
+ opt = self.optimizers()
37
+ scheduler = self.lr_schedulers()
38
+ loss, acc = self.model.forward(
39
+ batch["phoneme_ids"],
40
+ batch["phoneme_ids_len"],
41
+ batch["semantic_ids"],
42
+ batch["semantic_ids_len"],
43
+ batch["bert_feature"],
44
+ )
45
+ self.manual_backward(loss)
46
+ if batch_idx > 0 and batch_idx % 4 == 0:
47
+ opt.step()
48
+ opt.zero_grad()
49
+ scheduler.step()
50
+
51
+ self.log(
52
+ "total_loss",
53
+ loss,
54
+ on_step=True,
55
+ on_epoch=True,
56
+ prog_bar=True,
57
+ sync_dist=True,
58
+ )
59
+ self.log(
60
+ "lr",
61
+ scheduler.get_last_lr()[0],
62
+ on_epoch=True,
63
+ prog_bar=True,
64
+ sync_dist=True,
65
+ )
66
+ self.log(
67
+ f"top_{self.top_k}_acc",
68
+ acc,
69
+ on_step=True,
70
+ on_epoch=True,
71
+ prog_bar=True,
72
+ sync_dist=True,
73
+ )
74
+
75
+ def validation_step(self, batch: Dict, batch_idx: int):
76
+ return
77
+
78
+ # # get loss
79
+ # loss, acc = self.model.forward(
80
+ # batch['phoneme_ids'], batch['phoneme_ids_len'],
81
+ # batch['semantic_ids'], batch['semantic_ids_len'],
82
+ # batch['bert_feature']
83
+ # )
84
+ #
85
+ # self.log(
86
+ # "val_total_loss",
87
+ # loss,
88
+ # on_step=True,
89
+ # on_epoch=True,
90
+ # prog_bar=True,
91
+ # sync_dist=True)
92
+ # self.log(
93
+ # f"val_top_{self.top_k}_acc",
94
+ # acc,
95
+ # on_step=True,
96
+ # on_epoch=True,
97
+ # prog_bar=True,
98
+ # sync_dist=True)
99
+ #
100
+ # # get infer output
101
+ # semantic_len = batch['semantic_ids'].size(1)
102
+ # prompt_len = min(int(semantic_len * 0.5), 150)
103
+ # prompt = batch['semantic_ids'][:, :prompt_len]
104
+ # pred_semantic = self.model.infer(batch['phoneme_ids'],
105
+ # batch['phoneme_ids_len'], prompt,
106
+ # batch['bert_feature']
107
+ # )
108
+ # save_name = f'semantic_toks_{batch_idx}.pt'
109
+ # save_path = os.path.join(self.eval_dir, save_name)
110
+ # torch.save(pred_semantic.detach().cpu(), save_path)
111
+
112
+ def configure_optimizers(self):
113
+ model_parameters = self.model.parameters()
114
+ parameters_names = []
115
+ parameters_names.append(
116
+ [name_param_pair[0] for name_param_pair in self.model.named_parameters()]
117
+ )
118
+ lm_opt = ScaledAdam(
119
+ model_parameters,
120
+ lr=0.01,
121
+ betas=(0.9, 0.95),
122
+ clipping_scale=2.0,
123
+ parameters_names=parameters_names,
124
+ show_dominant_parameters=False,
125
+ clipping_update_period=1000,
126
+ )
127
+
128
+ return {
129
+ "optimizer": lm_opt,
130
+ "lr_scheduler": {
131
+ "scheduler": WarmupCosineLRSchedule(
132
+ lm_opt,
133
+ init_lr=self.config["optimizer"]["lr_init"],
134
+ peak_lr=self.config["optimizer"]["lr"],
135
+ end_lr=self.config["optimizer"]["lr_end"],
136
+ warmup_steps=self.config["optimizer"]["warmup_steps"],
137
+ total_steps=self.config["optimizer"]["decay_steps"],
138
+ )
139
+ },
140
+ }
AR/models/t2s_model.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
2
+ import torch
3
+ from tqdm import tqdm
4
+
5
+ from AR.models.utils import make_pad_mask
6
+ from AR.models.utils import (
7
+ topk_sampling,
8
+ sample,
9
+ logits_to_probs,
10
+ multinomial_sample_one_no_sync,
11
+ )
12
+ from AR.modules.embedding import SinePositionalEmbedding
13
+ from AR.modules.embedding import TokenEmbedding
14
+ from AR.modules.transformer import LayerNorm
15
+ from AR.modules.transformer import TransformerEncoder
16
+ from AR.modules.transformer import TransformerEncoderLayer
17
+ from torch import nn
18
+ from torch.nn import functional as F
19
+ from torchmetrics.classification import MulticlassAccuracy
20
+
21
+ default_config = {
22
+ "embedding_dim": 512,
23
+ "hidden_dim": 512,
24
+ "num_head": 8,
25
+ "num_layers": 12,
26
+ "num_codebook": 8,
27
+ "p_dropout": 0.0,
28
+ "vocab_size": 1024 + 1,
29
+ "phoneme_vocab_size": 512,
30
+ "EOS": 1024,
31
+ }
32
+
33
+
34
+ class Text2SemanticDecoder(nn.Module):
35
+ def __init__(self, config, norm_first=False, top_k=3):
36
+ super(Text2SemanticDecoder, self).__init__()
37
+ self.model_dim = config["model"]["hidden_dim"]
38
+ self.embedding_dim = config["model"]["embedding_dim"]
39
+ self.num_head = config["model"]["head"]
40
+ self.num_layers = config["model"]["n_layer"]
41
+ self.norm_first = norm_first
42
+ self.vocab_size = config["model"]["vocab_size"]
43
+ self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
44
+ self.p_dropout = config["model"]["dropout"]
45
+ self.EOS = config["model"]["EOS"]
46
+ self.norm_first = norm_first
47
+ assert self.EOS == self.vocab_size - 1
48
+ # should be same as num of kmeans bin
49
+ # assert self.EOS == 1024
50
+ self.bert_proj = nn.Linear(1024, self.embedding_dim)
51
+ self.ar_text_embedding = TokenEmbedding(
52
+ self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
53
+ )
54
+ self.ar_text_position = SinePositionalEmbedding(
55
+ self.embedding_dim, dropout=0.1, scale=False, alpha=True
56
+ )
57
+ self.ar_audio_embedding = TokenEmbedding(
58
+ self.embedding_dim, self.vocab_size, self.p_dropout
59
+ )
60
+ self.ar_audio_position = SinePositionalEmbedding(
61
+ self.embedding_dim, dropout=0.1, scale=False, alpha=True
62
+ )
63
+
64
+ self.h = TransformerEncoder(
65
+ TransformerEncoderLayer(
66
+ d_model=self.model_dim,
67
+ nhead=self.num_head,
68
+ dim_feedforward=self.model_dim * 4,
69
+ dropout=0.1,
70
+ batch_first=True,
71
+ norm_first=norm_first,
72
+ ),
73
+ num_layers=self.num_layers,
74
+ norm=LayerNorm(self.model_dim) if norm_first else None,
75
+ )
76
+
77
+ self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
78
+ self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
79
+
80
+ self.ar_accuracy_metric = MulticlassAccuracy(
81
+ self.vocab_size,
82
+ top_k=top_k,
83
+ average="micro",
84
+ multidim_average="global",
85
+ ignore_index=self.EOS,
86
+ )
87
+
88
+ def forward(self, x, x_lens, y, y_lens, bert_feature):
89
+ """
90
+ x: phoneme_ids
91
+ y: semantic_ids
92
+ """
93
+ x = self.ar_text_embedding(x)
94
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
95
+ x = self.ar_text_position(x)
96
+ x_mask = make_pad_mask(x_lens)
97
+
98
+ y_mask = make_pad_mask(y_lens)
99
+ y_mask_int = y_mask.type(torch.int64)
100
+ codes = y.type(torch.int64) * (1 - y_mask_int)
101
+
102
+ # Training
103
+ # AR Decoder
104
+ y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
105
+ x_len = x_lens.max()
106
+ y_len = y_lens.max()
107
+ y_emb = self.ar_audio_embedding(y)
108
+ y_pos = self.ar_audio_position(y_emb)
109
+
110
+ xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
111
+ ar_xy_padding_mask = xy_padding_mask
112
+
113
+ x_attn_mask = F.pad(
114
+ torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
115
+ (0, y_len),
116
+ value=True,
117
+ )
118
+ y_attn_mask = F.pad(
119
+ torch.triu(
120
+ torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
121
+ diagonal=1,
122
+ ),
123
+ (x_len, 0),
124
+ value=False,
125
+ )
126
+ xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
127
+ bsz, src_len = x.shape[0], x_len + y_len
128
+ _xy_padding_mask = (
129
+ ar_xy_padding_mask.view(bsz, 1, 1, src_len)
130
+ .expand(-1, self.num_head, -1, -1)
131
+ .reshape(bsz * self.num_head, 1, src_len)
132
+ )
133
+ xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
134
+ new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
135
+ new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
136
+ xy_attn_mask = new_attn_mask
137
+ # x 和完整的 y 一次性输入模型
138
+ xy_pos = torch.concat([x, y_pos], dim=1)
139
+ xy_dec, _ = self.h(
140
+ (xy_pos, None),
141
+ mask=xy_attn_mask,
142
+ )
143
+ logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
144
+ # loss
145
+ # from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum
146
+ loss = F.cross_entropy(logits, targets, reduction="sum")
147
+ acc = self.ar_accuracy_metric(logits.detach(), targets).item()
148
+ return loss, acc
149
+
150
+ # 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
151
+ def infer(
152
+ self,
153
+ x,
154
+ x_lens,
155
+ prompts,
156
+ bert_feature,
157
+ top_k: int = -100,
158
+ early_stop_num: int = -1,
159
+ temperature: float = 1.0,
160
+ ):
161
+ x = self.ar_text_embedding(x)
162
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
163
+ x = self.ar_text_position(x)
164
+
165
+ # AR Decoder
166
+ y = prompts
167
+ prefix_len = y.shape[1]
168
+ x_len = x.shape[1]
169
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
170
+ stop = False
171
+ for _ in tqdm(range(1500)):
172
+ y_emb = self.ar_audio_embedding(y)
173
+ y_pos = self.ar_audio_position(y_emb)
174
+ # x 和逐渐增长的 y 一起输入给模型
175
+ xy_pos = torch.concat([x, y_pos], dim=1)
176
+ y_len = y.shape[1]
177
+ x_attn_mask_pad = F.pad(
178
+ x_attn_mask,
179
+ (0, y_len),
180
+ value=True,
181
+ )
182
+ y_attn_mask = F.pad(
183
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
184
+ (x_len, 0),
185
+ value=False,
186
+ )
187
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
188
+ y.device
189
+ )
190
+
191
+ xy_dec, _ = self.h(
192
+ (xy_pos, None),
193
+ mask=xy_attn_mask,
194
+ )
195
+ logits = self.ar_predict_layer(xy_dec[:, -1])
196
+ samples = topk_sampling(
197
+ logits, top_k=top_k, top_p=1.0, temperature=temperature
198
+ )
199
+
200
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
201
+ print("use early stop num:", early_stop_num)
202
+ stop = True
203
+
204
+ if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
205
+ # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
206
+ stop = True
207
+ if stop:
208
+ if prompts.shape[1] == y.shape[1]:
209
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
210
+ print("bad zero prediction")
211
+ print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
212
+ break
213
+ # 本次生成的 semantic_ids 和之前的 y 构成新的 y
214
+ # print(samples.shape)#[1,1]#第一个1是bs
215
+ # import os
216
+ # os._exit(2333)
217
+ y = torch.concat([y, samples], dim=1)
218
+ return y
219
+
220
+ def pad_y_eos(self, y, y_mask_int, eos_id):
221
+ targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
222
+ y_mask_int, (0, 1), value=1
223
+ )
224
+ # 错位
225
+ return targets[:, :-1], targets[:, 1:]
226
+
227
+ def infer_panel(
228
+ self,
229
+ x, #####全部文本token
230
+ x_lens,
231
+ prompts, ####参考音频token
232
+ bert_feature,
233
+ top_k: int = -100,
234
+ early_stop_num: int = -1,
235
+ temperature: float = 1.0,
236
+ ):
237
+ x = self.ar_text_embedding(x)
238
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
239
+ x = self.ar_text_position(x)
240
+
241
+ # AR Decoder
242
+ y = prompts
243
+ prefix_len = y.shape[1]
244
+ x_len = x.shape[1]
245
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
246
+ stop = False
247
+ # print(1111111,self.num_layers)
248
+ cache = {
249
+ "all_stage": self.num_layers,
250
+ "k": [None] * self.num_layers, ###根据配置自己手写
251
+ "v": [None] * self.num_layers,
252
+ # "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存,每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了
253
+ "y_emb": None, ##只需要对最新的samples求emb,再拼历史的就行
254
+ # "logits":None,###原版就已经只对结尾求再拼接了,不用管
255
+ # "xy_dec":None,###不需要,本来只需要最后一个做logits
256
+ "first_infer": 1,
257
+ "stage": 0,
258
+ }
259
+ for idx in tqdm(range(1500)):
260
+ if cache["first_infer"] == 1:
261
+ y_emb = self.ar_audio_embedding(y)
262
+ else:
263
+ y_emb = torch.cat(
264
+ [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
265
+ )
266
+ cache["y_emb"] = y_emb
267
+ y_pos = self.ar_audio_position(y_emb)
268
+ # x 和逐渐增长的 y 一起输入给模型
269
+ if cache["first_infer"] == 1:
270
+ xy_pos = torch.concat([x, y_pos], dim=1)
271
+ else:
272
+ xy_pos = y_pos[:, -1:]
273
+ y_len = y_pos.shape[1]
274
+ ###以下3个不做缓存
275
+ if cache["first_infer"] == 1:
276
+ x_attn_mask_pad = F.pad(
277
+ x_attn_mask,
278
+ (0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
279
+ value=True,
280
+ )
281
+ y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
282
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
283
+ (x_len, 0),
284
+ value=False,
285
+ )
286
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
287
+ y.device
288
+ )
289
+ else:
290
+ ###最右边一列(是错的)
291
+ # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
292
+ # xy_attn_mask[:,-1]=False
293
+ ###最下面一行(是对的)
294
+ xy_attn_mask = torch.zeros(
295
+ (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
296
+ )
297
+ # pdb.set_trace()
298
+ ###缓存重头戏
299
+ # print(1111,xy_pos.shape,xy_attn_mask.shape,x_len,y_len)
300
+ xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
301
+ logits = self.ar_predict_layer(
302
+ xy_dec[:, -1]
303
+ ) ##不用改,如果用了cache的默认就是只有一帧,取最后一帧一样的
304
+ # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
305
+ samples = sample(
306
+ logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35
307
+ )[0].unsqueeze(0)
308
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
309
+ print("use early stop num:", early_stop_num)
310
+ stop = True
311
+
312
+ if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
313
+ # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
314
+ stop = True
315
+ if stop:
316
+ if prompts.shape[1] == y.shape[1]:
317
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
318
+ print("bad zero prediction")
319
+ print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
320
+ break
321
+ # 本次生成的 semantic_ids 和之前的 y 构成新的 y
322
+ # print(samples.shape)#[1,1]#第一个1是bs
323
+ y = torch.concat([y, samples], dim=1)
324
+ cache["first_infer"] = 0
325
+ return y, idx
AR/models/utils.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/utils.py\
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def sequence_mask(length, max_length=None):
7
+ if max_length is None:
8
+ max_length = length.max()
9
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
10
+ return x.unsqueeze(0) < length.unsqueeze(1)
11
+
12
+
13
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
14
+ """
15
+ Args:
16
+ lengths:
17
+ A 1-D tensor containing sentence lengths.
18
+ max_len:
19
+ The length of masks.
20
+ Returns:
21
+ Return a 2-D bool tensor, where masked positions
22
+ are filled with `True` and non-masked positions are
23
+ filled with `False`.
24
+
25
+ #>>> lengths = torch.tensor([1, 3, 2, 5])
26
+ #>>> make_pad_mask(lengths)
27
+ tensor([[False, True, True, True, True],
28
+ [False, False, False, True, True],
29
+ [False, False, True, True, True],
30
+ [False, False, False, False, False]])
31
+ """
32
+ assert lengths.ndim == 1, lengths.ndim
33
+ max_len = max(max_len, lengths.max())
34
+ n = lengths.size(0)
35
+ seq_range = torch.arange(0, max_len, device=lengths.device)
36
+ expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
37
+
38
+ return expaned_lengths >= lengths.unsqueeze(-1)
39
+
40
+
41
+ # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
42
+ def top_k_top_p_filtering(
43
+ logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
44
+ ):
45
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
46
+ Args:
47
+ logits: logits distribution shape (batch size, vocabulary size)
48
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
49
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
50
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
51
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
52
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
53
+ """
54
+ if top_k > 0:
55
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
56
+ # Remove all tokens with a probability less than the last token of the top-k
57
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
58
+ logits[indices_to_remove] = filter_value
59
+
60
+ if top_p < 1.0:
61
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
62
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
63
+
64
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
65
+ sorted_indices_to_remove = cumulative_probs > top_p
66
+ if min_tokens_to_keep > 1:
67
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
68
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
69
+ # Shift the indices to the right to keep also the first token above the threshold
70
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
71
+ sorted_indices_to_remove[..., 0] = 0
72
+
73
+ # scatter sorted tensors to original indexing
74
+ indices_to_remove = sorted_indices_to_remove.scatter(
75
+ 1, sorted_indices, sorted_indices_to_remove
76
+ )
77
+ logits[indices_to_remove] = filter_value
78
+ return logits
79
+
80
+
81
+ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
82
+ # temperature: (`optional`) float
83
+ # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
84
+ # top_k: (`optional`) int
85
+ # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
86
+ # top_p: (`optional`) float
87
+ # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
88
+
89
+ # Temperature (higher temperature => more likely to sample low probability tokens)
90
+ if temperature != 1.0:
91
+ logits = logits / temperature
92
+ # Top-p/top-k filtering
93
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
94
+ # Sample
95
+ token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
96
+ return token
97
+
98
+
99
+ from typing import Optional, Tuple
100
+
101
+
102
+ def multinomial_sample_one_no_sync(
103
+ probs_sort,
104
+ ): # Does multinomial sampling without a cuda synchronization
105
+ q = torch.empty_like(probs_sort).exponential_(1)
106
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
107
+
108
+
109
+ def logits_to_probs(
110
+ logits,
111
+ previous_tokens: Optional[torch.Tensor] = None,
112
+ temperature: float = 1.0,
113
+ top_k: Optional[int] = None,
114
+ top_p: Optional[int] = None,
115
+ repetition_penalty: float = 1.0,
116
+ ):
117
+ previous_tokens = previous_tokens.squeeze()
118
+ # print(logits.shape,previous_tokens.shape)
119
+ # pdb.set_trace()
120
+ if previous_tokens is not None and repetition_penalty != 1.0:
121
+ previous_tokens = previous_tokens.long()
122
+ score = torch.gather(logits, dim=0, index=previous_tokens)
123
+ score = torch.where(
124
+ score < 0, score * repetition_penalty, score / repetition_penalty
125
+ )
126
+ logits.scatter_(dim=0, index=previous_tokens, src=score)
127
+
128
+ if top_p is not None and top_p < 1.0:
129
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
130
+ cum_probs = torch.cumsum(
131
+ torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
132
+ )
133
+ sorted_indices_to_remove = cum_probs > top_p
134
+ sorted_indices_to_remove[0] = False # keep at least one option
135
+ indices_to_remove = sorted_indices_to_remove.scatter(
136
+ dim=0, index=sorted_indices, src=sorted_indices_to_remove
137
+ )
138
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
139
+
140
+ logits = logits / max(temperature, 1e-5)
141
+
142
+ if top_k is not None:
143
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
144
+ pivot = v.select(-1, -1).unsqueeze(-1)
145
+ logits = torch.where(logits < pivot, -float("Inf"), logits)
146
+
147
+ probs = torch.nn.functional.softmax(logits, dim=-1)
148
+ return probs
149
+
150
+
151
+ def sample(
152
+ logits,
153
+ previous_tokens: Optional[torch.Tensor] = None,
154
+ **sampling_kwargs,
155
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
156
+ probs = logits_to_probs(
157
+ logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
158
+ )
159
+ idx_next = multinomial_sample_one_no_sync(probs)
160
+ return idx_next, probs
AR/modules/__init__.py ADDED
File without changes
AR/modules/activation.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ import torch
5
+ from torch import Tensor
6
+ from torch.nn import Linear
7
+ from torch.nn import Module
8
+ from torch.nn.init import constant_
9
+ from torch.nn.init import xavier_normal_
10
+ from torch.nn.init import xavier_uniform_
11
+ from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
12
+ from torch.nn.parameter import Parameter
13
+
14
+ from torch.nn import functional as F
15
+ from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
16
+
17
+ F.multi_head_attention_forward = multi_head_attention_forward_patched
18
+
19
+
20
+ class MultiheadAttention(Module):
21
+ r"""Allows the model to jointly attend to information
22
+ from different representation subspaces as described in the paper:
23
+ `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
24
+
25
+ Multi-Head Attention is defined as:
26
+
27
+ .. math::
28
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
29
+
30
+ where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
31
+
32
+ ``forward()`` will use a special optimized implementation if all of the following
33
+ conditions are met:
34
+
35
+ - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
36
+ restriction will be loosened in the future.)
37
+ - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
38
+ - training is disabled (using ``.eval()``)
39
+ - dropout is 0
40
+ - ``add_bias_kv`` is ``False``
41
+ - ``add_zero_attn`` is ``False``
42
+ - ``batch_first`` is ``True`` and the input is batched
43
+ - ``kdim`` and ``vdim`` are equal to ``embed_dim``
44
+ - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
45
+ - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
46
+ nor ``attn_mask`` is passed
47
+
48
+ If the optimized implementation is in use, a
49
+ `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
50
+ ``query``/``key``/``value`` to represent padding more efficiently than using a
51
+ padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
52
+ will be returned, and an additional speedup proportional to the fraction of the input
53
+ that is padding can be expected.
54
+
55
+ Args:
56
+ embed_dim: Total dimension of the model.
57
+ num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
58
+ across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
59
+ dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
60
+ bias: If specified, adds bias to input / output projection layers. Default: ``True``.
61
+ add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
62
+ add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
63
+ Default: ``False``.
64
+ kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
65
+ vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
66
+ batch_first: If ``True``, then the input and output tensors are provided
67
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
68
+
69
+ Examples::
70
+
71
+ >>> # xdoctest: +SKIP
72
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
73
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
74
+
75
+ """
76
+ __constants__ = ["batch_first"]
77
+ bias_k: Optional[torch.Tensor]
78
+ bias_v: Optional[torch.Tensor]
79
+
80
+ def __init__(
81
+ self,
82
+ embed_dim,
83
+ num_heads,
84
+ dropout=0.0,
85
+ bias=True,
86
+ add_bias_kv=False,
87
+ add_zero_attn=False,
88
+ kdim=None,
89
+ vdim=None,
90
+ batch_first=False,
91
+ linear1_cls=Linear,
92
+ linear2_cls=Linear,
93
+ device=None,
94
+ dtype=None,
95
+ ) -> None:
96
+ factory_kwargs = {"device": device, "dtype": dtype}
97
+ super(MultiheadAttention, self).__init__()
98
+ self.embed_dim = embed_dim
99
+ self.kdim = kdim if kdim is not None else embed_dim
100
+ self.vdim = vdim if vdim is not None else embed_dim
101
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
102
+
103
+ self.num_heads = num_heads
104
+ self.dropout = dropout
105
+ self.batch_first = batch_first
106
+ self.head_dim = embed_dim // num_heads
107
+ assert (
108
+ self.head_dim * num_heads == self.embed_dim
109
+ ), "embed_dim must be divisible by num_heads"
110
+
111
+ if add_bias_kv:
112
+ self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
113
+ self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
114
+ else:
115
+ self.bias_k = self.bias_v = None
116
+
117
+ if linear1_cls == Linear:
118
+ if not self._qkv_same_embed_dim:
119
+ self.q_proj_weight = Parameter(
120
+ torch.empty((embed_dim, embed_dim), **factory_kwargs)
121
+ )
122
+ self.k_proj_weight = Parameter(
123
+ torch.empty((embed_dim, self.kdim), **factory_kwargs)
124
+ )
125
+ self.v_proj_weight = Parameter(
126
+ torch.empty((embed_dim, self.vdim), **factory_kwargs)
127
+ )
128
+ self.register_parameter("in_proj_weight", None)
129
+ else:
130
+ self.in_proj_weight = Parameter(
131
+ torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
132
+ )
133
+ self.register_parameter("q_proj_weight", None)
134
+ self.register_parameter("k_proj_weight", None)
135
+ self.register_parameter("v_proj_weight", None)
136
+
137
+ if bias:
138
+ self.in_proj_bias = Parameter(
139
+ torch.empty(3 * embed_dim, **factory_kwargs)
140
+ )
141
+ else:
142
+ self.register_parameter("in_proj_bias", None)
143
+ self.out_proj = NonDynamicallyQuantizableLinear(
144
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
145
+ )
146
+
147
+ self._reset_parameters()
148
+ else:
149
+ if not self._qkv_same_embed_dim:
150
+ raise NotImplementedError
151
+ else:
152
+ self.in_proj_linear = linear1_cls(
153
+ embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
154
+ )
155
+ self.in_proj_weight = self.in_proj_linear.weight
156
+
157
+ self.register_parameter("q_proj_weight", None)
158
+ self.register_parameter("k_proj_weight", None)
159
+ self.register_parameter("v_proj_weight", None)
160
+
161
+ if bias:
162
+ self.in_proj_bias = self.in_proj_linear.bias
163
+ else:
164
+ self.register_parameter("in_proj_bias", None)
165
+
166
+ self.out_proj = linear2_cls(
167
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
168
+ )
169
+
170
+ if self.bias_k is not None:
171
+ xavier_normal_(self.bias_k)
172
+ if self.bias_v is not None:
173
+ xavier_normal_(self.bias_v)
174
+
175
+ self.add_zero_attn = add_zero_attn
176
+
177
+ def _reset_parameters(self):
178
+ if self._qkv_same_embed_dim:
179
+ xavier_uniform_(self.in_proj_weight)
180
+ else:
181
+ xavier_uniform_(self.q_proj_weight)
182
+ xavier_uniform_(self.k_proj_weight)
183
+ xavier_uniform_(self.v_proj_weight)
184
+
185
+ if self.in_proj_bias is not None:
186
+ constant_(self.in_proj_bias, 0.0)
187
+ constant_(self.out_proj.bias, 0.0)
188
+
189
+ if self.bias_k is not None:
190
+ xavier_normal_(self.bias_k)
191
+ if self.bias_v is not None:
192
+ xavier_normal_(self.bias_v)
193
+
194
+ def __setstate__(self, state):
195
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
196
+ if "_qkv_same_embed_dim" not in state:
197
+ state["_qkv_same_embed_dim"] = True
198
+
199
+ super(MultiheadAttention, self).__setstate__(state)
200
+
201
+ def forward(
202
+ self,
203
+ query: Tensor,
204
+ key: Tensor,
205
+ value: Tensor,
206
+ key_padding_mask: Optional[Tensor] = None,
207
+ need_weights: bool = True,
208
+ attn_mask: Optional[Tensor] = None,
209
+ average_attn_weights: bool = True,
210
+ cache=None,
211
+ ) -> Tuple[Tensor, Optional[Tensor]]:
212
+ r"""
213
+ Args:
214
+ query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
215
+ or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
216
+ :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
217
+ Queries are compared against key-value pairs to produce the output.
218
+ See "Attention Is All You Need" for more details.
219
+ key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
220
+ or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
221
+ :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
222
+ See "Attention Is All You Need" for more details.
223
+ value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
224
+ ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
225
+ sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
226
+ See "Attention Is All You Need" for more details.
227
+ key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
228
+ to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
229
+ Binary and byte masks are supported.
230
+ For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
231
+ the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
232
+ need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
233
+ Default: ``True``.
234
+ attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
235
+ :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
236
+ :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
237
+ broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
238
+ Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
239
+ corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
240
+ corresponding position is not allowed to attend. For a float mask, the mask values will be added to
241
+ the attention weight.
242
+ average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
243
+ heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
244
+ effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
245
+
246
+ Outputs:
247
+ - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
248
+ :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
249
+ where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
250
+ embedding dimension ``embed_dim``.
251
+ - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
252
+ returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
253
+ :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
254
+ :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
255
+ head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
256
+
257
+ .. note::
258
+ `batch_first` argument is ignored for unbatched inputs.
259
+ """
260
+ is_batched = query.dim() == 3
261
+ if key_padding_mask is not None:
262
+ _kpm_dtype = key_padding_mask.dtype
263
+ if _kpm_dtype != torch.bool and not torch.is_floating_point(
264
+ key_padding_mask
265
+ ):
266
+ raise AssertionError(
267
+ "only bool and floating types of key_padding_mask are supported"
268
+ )
269
+ why_not_fast_path = ""
270
+ if not is_batched:
271
+ why_not_fast_path = (
272
+ f"input not batched; expected query.dim() of 3 but got {query.dim()}"
273
+ )
274
+ elif query is not key or key is not value:
275
+ # When lifting this restriction, don't forget to either
276
+ # enforce that the dtypes all match or test cases where
277
+ # they don't!
278
+ why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
279
+ elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
280
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
281
+ elif (
282
+ self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
283
+ ):
284
+ # this case will fail anyway, but at least they'll get a useful error message.
285
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
286
+ elif self.training:
287
+ why_not_fast_path = "training is enabled"
288
+ elif not self.batch_first:
289
+ why_not_fast_path = "batch_first was not True"
290
+ elif self.bias_k is not None:
291
+ why_not_fast_path = "self.bias_k was not None"
292
+ elif self.bias_v is not None:
293
+ why_not_fast_path = "self.bias_v was not None"
294
+ elif self.dropout:
295
+ why_not_fast_path = f"dropout was {self.dropout}, required zero"
296
+ elif self.add_zero_attn:
297
+ why_not_fast_path = "add_zero_attn was enabled"
298
+ elif not self._qkv_same_embed_dim:
299
+ why_not_fast_path = "_qkv_same_embed_dim was not True"
300
+ elif attn_mask is not None:
301
+ why_not_fast_path = "attn_mask was not None"
302
+ elif query.is_nested and key_padding_mask is not None:
303
+ why_not_fast_path = (
304
+ "key_padding_mask is not supported with NestedTensor input"
305
+ )
306
+ elif self.num_heads % 2 == 1:
307
+ why_not_fast_path = "num_heads is odd"
308
+ elif torch.is_autocast_enabled():
309
+ why_not_fast_path = "autocast is enabled"
310
+
311
+ if not why_not_fast_path:
312
+ tensor_args = (
313
+ query,
314
+ key,
315
+ value,
316
+ self.in_proj_weight,
317
+ self.in_proj_bias,
318
+ self.out_proj.weight,
319
+ self.out_proj.bias,
320
+ )
321
+ # We have to use list comprehensions below because TorchScript does not support
322
+ # generator expressions.
323
+ if torch.overrides.has_torch_function(tensor_args):
324
+ why_not_fast_path = "some Tensor argument has_torch_function"
325
+ elif not all(
326
+ [
327
+ (x is None or x.is_cuda or "cpu" in str(x.device))
328
+ for x in tensor_args
329
+ ]
330
+ ):
331
+ why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
332
+ elif torch.is_grad_enabled() and any(
333
+ [x is not None and x.requires_grad for x in tensor_args]
334
+ ):
335
+ why_not_fast_path = (
336
+ "grad is enabled and at least one of query or the "
337
+ "input/output projection weights or biases requires_grad"
338
+ )
339
+ if not why_not_fast_path:
340
+ return torch._native_multi_head_attention(
341
+ query,
342
+ key,
343
+ value,
344
+ self.embed_dim,
345
+ self.num_heads,
346
+ self.in_proj_weight,
347
+ self.in_proj_bias,
348
+ self.out_proj.weight,
349
+ self.out_proj.bias,
350
+ key_padding_mask if key_padding_mask is not None else attn_mask,
351
+ need_weights,
352
+ average_attn_weights,
353
+ 1
354
+ if key_padding_mask is not None
355
+ else 0
356
+ if attn_mask is not None
357
+ else None,
358
+ )
359
+
360
+ any_nested = query.is_nested or key.is_nested or value.is_nested
361
+ assert not any_nested, (
362
+ "MultiheadAttention does not support NestedTensor outside of its fast path. "
363
+ + f"The fast path was not hit because {why_not_fast_path}"
364
+ )
365
+
366
+ if self.batch_first and is_batched:
367
+ # make sure that the transpose op does not affect the "is" property
368
+ if key is value:
369
+ if query is key:
370
+ query = key = value = query.transpose(1, 0)
371
+ else:
372
+ query, key = [x.transpose(1, 0) for x in (query, key)]
373
+ value = key
374
+ else:
375
+ query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
376
+
377
+ if not self._qkv_same_embed_dim:
378
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
379
+ query,
380
+ key,
381
+ value,
382
+ self.embed_dim,
383
+ self.num_heads,
384
+ self.in_proj_weight,
385
+ self.in_proj_bias,
386
+ self.bias_k,
387
+ self.bias_v,
388
+ self.add_zero_attn,
389
+ self.dropout,
390
+ self.out_proj.weight,
391
+ self.out_proj.bias,
392
+ training=self.training,
393
+ key_padding_mask=key_padding_mask,
394
+ need_weights=need_weights,
395
+ attn_mask=attn_mask,
396
+ use_separate_proj_weight=True,
397
+ q_proj_weight=self.q_proj_weight,
398
+ k_proj_weight=self.k_proj_weight,
399
+ v_proj_weight=self.v_proj_weight,
400
+ average_attn_weights=average_attn_weights,
401
+ cache=cache,
402
+ )
403
+ else:
404
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
405
+ query,
406
+ key,
407
+ value,
408
+ self.embed_dim,
409
+ self.num_heads,
410
+ self.in_proj_weight,
411
+ self.in_proj_bias,
412
+ self.bias_k,
413
+ self.bias_v,
414
+ self.add_zero_attn,
415
+ self.dropout,
416
+ self.out_proj.weight,
417
+ self.out_proj.bias,
418
+ training=self.training,
419
+ key_padding_mask=key_padding_mask,
420
+ need_weights=need_weights,
421
+ attn_mask=attn_mask,
422
+ average_attn_weights=average_attn_weights,
423
+ cache=cache,
424
+ )
425
+ if self.batch_first and is_batched:
426
+ return attn_output.transpose(1, 0), attn_output_weights
427
+ else:
428
+ return attn_output, attn_output_weights
AR/modules/embedding.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
2
+ import math
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+
8
+ class TokenEmbedding(nn.Module):
9
+ def __init__(
10
+ self,
11
+ embedding_dim: int,
12
+ vocab_size: int,
13
+ dropout: float = 0.0,
14
+ ):
15
+ super().__init__()
16
+
17
+ self.vocab_size = vocab_size
18
+ self.embedding_dim = embedding_dim
19
+
20
+ self.dropout = torch.nn.Dropout(p=dropout)
21
+ self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
22
+
23
+ @property
24
+ def weight(self) -> torch.Tensor:
25
+ return self.word_embeddings.weight
26
+
27
+ def embedding(self, index: int) -> torch.Tensor:
28
+ return self.word_embeddings.weight[index : index + 1]
29
+
30
+ def forward(self, x: torch.Tensor):
31
+ x = self.word_embeddings(x)
32
+ x = self.dropout(x)
33
+ return x
34
+
35
+
36
+ class SinePositionalEmbedding(nn.Module):
37
+ def __init__(
38
+ self,
39
+ embedding_dim: int,
40
+ dropout: float = 0.0,
41
+ scale: bool = False,
42
+ alpha: bool = False,
43
+ ):
44
+ super().__init__()
45
+ self.embedding_dim = embedding_dim
46
+ self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
47
+ self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
48
+ self.dropout = torch.nn.Dropout(p=dropout)
49
+
50
+ self.reverse = False
51
+ self.pe = None
52
+ self.extend_pe(torch.tensor(0.0).expand(1, 4000))
53
+
54
+ def extend_pe(self, x):
55
+ """Reset the positional encodings."""
56
+ if self.pe is not None:
57
+ if self.pe.size(1) >= x.size(1):
58
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
59
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
60
+ return
61
+ pe = torch.zeros(x.size(1), self.embedding_dim)
62
+ if self.reverse:
63
+ position = torch.arange(
64
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
65
+ ).unsqueeze(1)
66
+ else:
67
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
68
+ div_term = torch.exp(
69
+ torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
70
+ * -(math.log(10000.0) / self.embedding_dim)
71
+ )
72
+ pe[:, 0::2] = torch.sin(position * div_term)
73
+ pe[:, 1::2] = torch.cos(position * div_term)
74
+ pe = pe.unsqueeze(0)
75
+ self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
76
+
77
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
78
+ self.extend_pe(x)
79
+ output = x.unsqueeze(-1) if x.ndim == 2 else x
80
+ output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
81
+ return self.dropout(output)
AR/modules/lr_schedulers.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/lr_schedulers.py
2
+ import math
3
+
4
+ import torch
5
+ from matplotlib import pyplot as plt
6
+ from torch import nn
7
+ from torch.optim import Adam
8
+
9
+
10
+ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
11
+ """
12
+ Implements Warmup learning rate schedule until 'warmup_steps', going from 'init_lr' to 'peak_lr' for multiple optimizers.
13
+ """
14
+
15
+ def __init__(
16
+ self,
17
+ optimizer,
18
+ init_lr,
19
+ peak_lr,
20
+ end_lr,
21
+ warmup_steps=10000,
22
+ total_steps=400000,
23
+ current_step=0,
24
+ ):
25
+ self.init_lr = init_lr
26
+ self.peak_lr = peak_lr
27
+ self.end_lr = end_lr
28
+ self.optimizer = optimizer
29
+ self._warmup_rate = (peak_lr - init_lr) / warmup_steps
30
+ self._decay_rate = (end_lr - peak_lr) / (total_steps - warmup_steps)
31
+ self._current_step = current_step
32
+ self.lr = init_lr
33
+ self.warmup_steps = warmup_steps
34
+ self.total_steps = total_steps
35
+ self._last_lr = [self.lr]
36
+
37
+ def set_lr(self, lr):
38
+ self._last_lr = [g["lr"] for g in self.optimizer.param_groups]
39
+ for g in self.optimizer.param_groups:
40
+ # g['lr'] = lr
41
+ g["lr"] = self.end_lr ###锁定用线性
42
+
43
+ def step(self):
44
+ if self._current_step < self.warmup_steps:
45
+ lr = self.init_lr + self._warmup_rate * self._current_step
46
+
47
+ elif self._current_step > self.total_steps:
48
+ lr = self.end_lr
49
+
50
+ else:
51
+ decay_ratio = (self._current_step - self.warmup_steps) / (
52
+ self.total_steps - self.warmup_steps
53
+ )
54
+ if decay_ratio < 0.0 or decay_ratio > 1.0:
55
+ raise RuntimeError(
56
+ "Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings."
57
+ )
58
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
59
+ lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)
60
+
61
+ self.lr = lr = self.end_lr = 0.002 ###锁定用线性###不听话,直接锁定!
62
+ self.set_lr(lr)
63
+ self.lr = lr
64
+ self._current_step += 1
65
+ return self.lr
66
+
67
+
68
+ if __name__ == "__main__":
69
+ m = nn.Linear(10, 10)
70
+ opt = Adam(m.parameters(), lr=1e-4)
71
+ s = WarmupCosineLRSchedule(
72
+ opt, 1e-6, 2e-4, 1e-6, warmup_steps=2000, total_steps=20000, current_step=0
73
+ )
74
+ lrs = []
75
+ for i in range(25000):
76
+ s.step()
77
+ lrs.append(s.lr)
78
+ print(s.lr)
79
+
80
+ plt.plot(lrs)
81
+ plt.plot(range(0, 25000), lrs)
82
+ plt.show()
AR/modules/optim.py ADDED
@@ -0,0 +1,622 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import contextlib
17
+ import logging
18
+ from collections import defaultdict
19
+ from typing import List
20
+ from typing import Tuple
21
+
22
+ import torch
23
+ from torch import Tensor
24
+ from torch.optim import Optimizer
25
+
26
+
27
+ class BatchedOptimizer(Optimizer):
28
+ """
29
+ This class adds to class Optimizer the capability to optimize parameters in batches:
30
+ it will stack the parameters and their grads for you so the optimizer can work
31
+ on tensors with an extra leading dimension. This is intended for speed with GPUs,
32
+ as it reduces the number of kernels launched in the optimizer.
33
+
34
+ Args:
35
+ params:
36
+ """
37
+
38
+ def __init__(self, params, defaults):
39
+ super(BatchedOptimizer, self).__init__(params, defaults)
40
+
41
+ @contextlib.contextmanager
42
+ def batched_params(self, param_group, group_params_names):
43
+ """
44
+ This function returns (technically, yields) a list of
45
+ of tuples (p, state), where
46
+ p is a `fake` parameter that is stacked (over axis 0) from real parameters
47
+ that share the same shape, and its gradient is also stacked;
48
+ `state` is the state corresponding to this batch of parameters
49
+ (it will be physically located in the "state" for one of the real
50
+ parameters, the last one that has any particular shape and dtype).
51
+
52
+ This function is decorated as a context manager so that it can
53
+ write parameters back to their "real" locations.
54
+
55
+ The idea is, instead of doing:
56
+ <code>
57
+ for p in group["params"]:
58
+ state = self.state[p]
59
+ ...
60
+ </code>
61
+ you can do:
62
+ <code>
63
+ with self.batched_params(group["params"]) as batches:
64
+ for p, state, p_names in batches:
65
+ ...
66
+ </code>
67
+
68
+ Args:
69
+ group: a parameter group, which is a list of parameters; should be
70
+ one of self.param_groups.
71
+ group_params_names: name for each parameter in group,
72
+ which is List[str].
73
+ """
74
+ batches = defaultdict(
75
+ list
76
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
77
+ batches_names = defaultdict(
78
+ list
79
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
80
+
81
+ assert len(param_group) == len(group_params_names)
82
+ for p, named_p in zip(param_group, group_params_names):
83
+ key = (str(p.dtype), *p.shape)
84
+ batches[key].append(p)
85
+ batches_names[key].append(named_p)
86
+
87
+ batches_names_keys = list(batches_names.keys())
88
+ sorted_idx = sorted(
89
+ range(len(batches_names)), key=lambda i: batches_names_keys[i])
90
+ batches_names = [
91
+ batches_names[batches_names_keys[idx]] for idx in sorted_idx
92
+ ]
93
+ batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
94
+
95
+ stacked_params_dict = dict()
96
+
97
+ # turn batches into a list, in deterministic order.
98
+ # tuples will contain tuples of (stacked_param, state, stacked_params_names),
99
+ # one for each batch in `batches`.
100
+ tuples = []
101
+
102
+ for batch, batch_names in zip(batches, batches_names):
103
+ p = batch[0]
104
+ # we arbitrarily store the state in the
105
+ # state corresponding to the 1st parameter in the
106
+ # group. class Optimizer will take care of saving/loading state.
107
+ state = self.state[p]
108
+ p_stacked = torch.stack(batch)
109
+ grad = torch.stack([
110
+ torch.zeros_like(p) if p.grad is None else p.grad for p in batch
111
+ ])
112
+ p_stacked.grad = grad
113
+ stacked_params_dict[key] = p_stacked
114
+ tuples.append((p_stacked, state, batch_names))
115
+
116
+ yield tuples # <-- calling code will do the actual optimization here!
117
+
118
+ for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
119
+ for i, p in enumerate(batch): # batch is list of Parameter
120
+ p.copy_(stacked_params[i])
121
+
122
+
123
+ class ScaledAdam(BatchedOptimizer):
124
+ """
125
+ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
126
+ proportional to the norm of that parameter; and also learn the scale of the parameter,
127
+ in log space, subject to upper and lower limits (as if we had factored each parameter as
128
+ param = underlying_param * log_scale.exp())
129
+
130
+
131
+ Args:
132
+ params: The parameters or param_groups to optimize (like other Optimizer subclasses)
133
+ lr: The learning rate. We will typically use a learning rate schedule that starts
134
+ at 0.03 and decreases over time, i.e. much higher than other common
135
+ optimizers.
136
+ clipping_scale: (e.g. 2.0)
137
+ A scale for gradient-clipping: if specified, the normalized gradients
138
+ over the whole model will be clipped to have 2-norm equal to
139
+ `clipping_scale` times the median 2-norm over the most recent period
140
+ of `clipping_update_period` minibatches. By "normalized gradients",
141
+ we mean after multiplying by the rms parameter value for this tensor
142
+ [for non-scalars]; this is appropriate because our update is scaled
143
+ by this quantity.
144
+ betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
145
+ Must satisfy 0 < beta <= beta2 < 1.
146
+ scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
147
+ scale of each parameter tensor and scalar parameters of the mode..
148
+ If each parameter were decomposed
149
+ as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
150
+ would be a the scaling factor on the learning rate of p_scale.
151
+ eps: A general-purpose epsilon to prevent division by zero
152
+ param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
153
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
154
+ parameter tensor to be >= this value)
155
+ param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
156
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
157
+ parameter tensor to be <= this value)
158
+ scalar_max: Maximum absolute value for scalar parameters (applicable if your
159
+ model has any parameters with numel() == 1).
160
+ size_update_period: The periodicity, in steps, with which we update the size (scale)
161
+ of the parameter tensor. This is provided to save a little time
162
+ in the update.
163
+ clipping_update_period: if clipping_scale is specified, this is the period
164
+ """
165
+
166
+ def __init__(
167
+ self,
168
+ params,
169
+ lr=3e-02,
170
+ clipping_scale=None,
171
+ betas=(0.9, 0.98),
172
+ scalar_lr_scale=0.1,
173
+ eps=1.0e-08,
174
+ param_min_rms=1.0e-05,
175
+ param_max_rms=3.0,
176
+ scalar_max=10.0,
177
+ size_update_period=4,
178
+ clipping_update_period=100,
179
+ parameters_names=None,
180
+ show_dominant_parameters=True, ):
181
+
182
+ assert parameters_names is not None, (
183
+ "Please prepare parameters_names,"
184
+ "which is a List[List[str]]. Each List[str] is for a group"
185
+ "and each str is for a parameter")
186
+ defaults = dict(
187
+ lr=lr,
188
+ clipping_scale=clipping_scale,
189
+ betas=betas,
190
+ scalar_lr_scale=scalar_lr_scale,
191
+ eps=eps,
192
+ param_min_rms=param_min_rms,
193
+ param_max_rms=param_max_rms,
194
+ scalar_max=scalar_max,
195
+ size_update_period=size_update_period,
196
+ clipping_update_period=clipping_update_period, )
197
+
198
+ super(ScaledAdam, self).__init__(params, defaults)
199
+ assert len(self.param_groups) == len(parameters_names)
200
+ self.parameters_names = parameters_names
201
+ self.show_dominant_parameters = show_dominant_parameters
202
+
203
+ def __setstate__(self, state):
204
+ super(ScaledAdam, self).__setstate__(state)
205
+
206
+ @torch.no_grad()
207
+ def step(self, closure=None):
208
+ """Performs a single optimization step.
209
+
210
+ Arguments:
211
+ closure (callable, optional): A closure that reevaluates the model
212
+ and returns the loss.
213
+ """
214
+ loss = None
215
+ if closure is not None:
216
+ with torch.enable_grad():
217
+ loss = closure()
218
+
219
+ batch = True
220
+
221
+ for group, group_params_names in zip(self.param_groups,
222
+ self.parameters_names):
223
+
224
+ with self.batched_params(group["params"],
225
+ group_params_names) as batches:
226
+
227
+ # batches is list of pairs (stacked_param, state). stacked_param is like
228
+ # a regular parameter, and will have a .grad, but the 1st dim corresponds to
229
+ # a stacking dim, it is not a real dim.
230
+
231
+ if (len(batches[0][1]) ==
232
+ 0): # if len(first state) == 0: not yet initialized
233
+ clipping_scale = 1
234
+ else:
235
+ clipping_scale = self._get_clipping_scale(group, batches)
236
+
237
+ for p, state, _ in batches:
238
+ # Perform optimization step.
239
+ # grad is not going to be None, we handled that when creating the batches.
240
+ grad = p.grad
241
+ if grad.is_sparse:
242
+ raise RuntimeError(
243
+ "ScaledAdam optimizer does not support sparse gradients"
244
+ )
245
+ # State initialization
246
+ if len(state) == 0:
247
+ self._init_state(group, p, state)
248
+
249
+ self._step_one_batch(group, p, state, clipping_scale)
250
+
251
+ return loss
252
+
253
+ def _init_state(self, group: dict, p: Tensor, state: dict):
254
+ """
255
+ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
256
+ is actually the batch dimension, corresponding to batched-together
257
+ parameters of a given shape.
258
+
259
+
260
+ Args:
261
+ group: Dict to look up configuration values.
262
+ p: The parameter that we are initializing the state for
263
+ state: Dict from string to whatever state we are initializing
264
+ """
265
+ size_update_period = group["size_update_period"]
266
+
267
+ state["step"] = 0
268
+
269
+ kwargs = {"device": p.device, "dtype": p.dtype}
270
+
271
+ # 'delta' implements conventional momentum. There are
272
+ # several different kinds of update going on, so rather than
273
+ # compute "exp_avg" like in Adam, we store and decay a
274
+ # parameter-change "delta", which combines all forms of
275
+ # update. this is equivalent to how it's done in Adam,
276
+ # except for the first few steps.
277
+ state["delta"] = torch.zeros_like(
278
+ p, memory_format=torch.preserve_format)
279
+
280
+ batch_size = p.shape[0]
281
+ numel = p.numel() // batch_size
282
+ numel = p.numel()
283
+
284
+ if numel > 1:
285
+ # "param_rms" just periodically records the scalar root-mean-square value of
286
+ # the parameter tensor.
287
+ # it has a shape like (batch_size, 1, 1, 1, 1)
288
+ param_rms = (
289
+ (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
290
+ state["param_rms"] = param_rms
291
+
292
+ state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
293
+ state["scale_grads"] = torch.zeros(size_update_period,
294
+ *param_rms.shape, **kwargs)
295
+
296
+ # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
297
+ state["exp_avg_sq"] = torch.zeros_like(
298
+ p, memory_format=torch.preserve_format)
299
+
300
+ def _get_clipping_scale(self,
301
+ group: dict,
302
+ tuples: List[Tuple[Tensor, dict, List[str]]]
303
+ ) -> float:
304
+ """
305
+ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
306
+ by this amount before applying the rest of the update.
307
+
308
+ Args:
309
+ group: the parameter group, an item in self.param_groups
310
+ tuples: a list of tuples of (param, state, param_names)
311
+ where param is a batched set of parameters,
312
+ with a .grad (1st dim is batch dim)
313
+ and state is the state-dict where optimization parameters are kept.
314
+ param_names is a List[str] while each str is name for a parameter
315
+ in batched set of parameters "param".
316
+ """
317
+ assert len(tuples) >= 1
318
+ clipping_scale = group["clipping_scale"]
319
+ (first_p, first_state, _) = tuples[0]
320
+ step = first_state["step"]
321
+ if clipping_scale is None or step == 0:
322
+ # no clipping. return early on step == 0 because the other
323
+ # parameters' state won't have been initialized yet.
324
+ return 1.0
325
+ clipping_update_period = group["clipping_update_period"]
326
+
327
+ tot_sumsq = torch.tensor(0.0, device=first_p.device)
328
+ for (p, state, param_names) in tuples:
329
+ grad = p.grad
330
+ if grad.is_sparse:
331
+ raise RuntimeError(
332
+ "ScaledAdam optimizer does not support sparse gradients")
333
+ if p.numel() == p.shape[0]: # a batch of scalars
334
+ tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
335
+ else:
336
+ tot_sumsq += ((grad * state["param_rms"])**2).sum()
337
+
338
+ tot_norm = tot_sumsq.sqrt()
339
+ if "model_norms" not in first_state:
340
+ first_state["model_norms"] = torch.zeros(
341
+ clipping_update_period, device=p.device)
342
+ first_state["model_norms"][step % clipping_update_period] = tot_norm
343
+
344
+ if step % clipping_update_period == 0:
345
+ # Print some stats.
346
+ # We don't reach here if step == 0 because we would have returned
347
+ # above.
348
+ sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
349
+ quartiles = []
350
+ for n in range(0, 5):
351
+ index = min(
352
+ clipping_update_period - 1,
353
+ (clipping_update_period // 4) * n, )
354
+ quartiles.append(sorted_norms[index].item())
355
+
356
+ median = quartiles[2]
357
+ threshold = clipping_scale * median
358
+ first_state["model_norm_threshold"] = threshold
359
+ percent_clipped = (first_state["num_clipped"] * 100.0 /
360
+ clipping_update_period
361
+ if "num_clipped" in first_state else 0.0)
362
+ first_state["num_clipped"] = 0
363
+ quartiles = " ".join(["%.3e" % x for x in quartiles])
364
+ logging.info(
365
+ f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
366
+ f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
367
+ )
368
+
369
+ if step < clipping_update_period:
370
+ return 1.0 # We have not yet estimated a norm to clip to.
371
+ else:
372
+ try:
373
+ model_norm_threshold = first_state["model_norm_threshold"]
374
+ except KeyError:
375
+ logging.info(
376
+ "Warning: model_norm_threshold not in state: possibly "
377
+ "you changed config when restarting, adding clipping_scale option?"
378
+ )
379
+ return 1.0
380
+ ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
381
+ if ans < 1.0:
382
+ first_state["num_clipped"] += 1
383
+ if ans < 0.1:
384
+ logging.warn(
385
+ f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
386
+ )
387
+ if self.show_dominant_parameters:
388
+ assert p.shape[0] == len(param_names)
389
+ self._show_gradient_dominating_parameter(tuples, tot_sumsq)
390
+ return ans
391
+
392
+ def _show_gradient_dominating_parameter(
393
+ self, tuples: List[Tuple[Tensor, dict, List[str]]],
394
+ tot_sumsq: Tensor):
395
+ """
396
+ Show information of parameter wihch dominanting tot_sumsq.
397
+
398
+ Args:
399
+ tuples: a list of tuples of (param, state, param_names)
400
+ where param is a batched set of parameters,
401
+ with a .grad (1st dim is batch dim)
402
+ and state is the state-dict where optimization parameters are kept.
403
+ param_names is a List[str] while each str is name for a parameter
404
+ in batched set of parameters "param".
405
+ tot_sumsq: sumsq of all parameters. Though it's could be calculated
406
+ from tuples, we still pass it to save some time.
407
+ """
408
+ all_sumsq_orig = {}
409
+ for (p, state, batch_param_names) in tuples:
410
+ # p is a stacked batch parameters.
411
+ batch_grad = p.grad
412
+ if p.numel() == p.shape[0]: # a batch of scalars
413
+ batch_sumsq_orig = batch_grad**2
414
+ # Dummpy values used by following `zip` statement.
415
+ batch_rms_orig = torch.ones(p.shape[0])
416
+ else:
417
+ batch_rms_orig = state["param_rms"]
418
+ batch_sumsq_orig = ((batch_grad * batch_rms_orig)**2).sum(
419
+ dim=list(range(1, batch_grad.ndim)))
420
+
421
+ for name, sumsq_orig, rms, grad in zip(batch_param_names,
422
+ batch_sumsq_orig,
423
+ batch_rms_orig, batch_grad):
424
+
425
+ proportion_orig = sumsq_orig / tot_sumsq
426
+ all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
427
+
428
+ assert torch.isclose(
429
+ sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
430
+ torch.tensor(1.0), )
431
+ sorted_by_proportion = {
432
+ k: v
433
+ for k, v in sorted(
434
+ all_sumsq_orig.items(),
435
+ key=lambda item: item[1][0],
436
+ reverse=True, )
437
+ }
438
+ dominant_param_name = next(iter(sorted_by_proportion))
439
+ (dominant_proportion, dominant_sumsq, dominant_rms,
440
+ dominant_grad, ) = sorted_by_proportion[dominant_param_name]
441
+ logging.info(f"Parameter Dominanting tot_sumsq {dominant_param_name}"
442
+ f" with proportion {dominant_proportion:.2f},"
443
+ f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
444
+ f"={dominant_sumsq:.3e},"
445
+ f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
446
+ f" orig_rms_sq={(dominant_rms**2).item():.3e}")
447
+
448
+ def _step_one_batch(self,
449
+ group: dict,
450
+ p: Tensor,
451
+ state: dict,
452
+ clipping_scale: float):
453
+ """
454
+ Do the step for one parameter, which is actually going to be a batch of
455
+ `real` parameters, with dim 0 as the batch dim.
456
+ Args:
457
+ group: dict to look up configuration values
458
+ p: parameter to update (actually multiple parameters stacked together
459
+ as a batch)
460
+ state: state-dict for p, to look up the optimizer state
461
+ """
462
+ lr = group["lr"]
463
+ size_update_period = group["size_update_period"]
464
+ beta1 = group["betas"][0]
465
+
466
+ grad = p.grad
467
+ if clipping_scale != 1.0:
468
+ grad = grad * clipping_scale
469
+ step = state["step"]
470
+ delta = state["delta"]
471
+
472
+ delta.mul_(beta1)
473
+ batch_size = p.shape[0]
474
+ numel = p.numel() // batch_size
475
+ if numel > 1:
476
+ # Update the size/scale of p, and set param_rms
477
+ scale_grads = state["scale_grads"]
478
+ scale_grads[step % size_update_period] = (p * grad).sum(
479
+ dim=list(range(1, p.ndim)), keepdim=True)
480
+ if step % size_update_period == size_update_period - 1:
481
+ param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
482
+ param_rms.copy_((p**2)
483
+ .mean(dim=list(range(1, p.ndim)), keepdim=True)
484
+ .sqrt())
485
+ if step > 0:
486
+ # self._size_update() learns the overall scale on the
487
+ # parameter, by shrinking or expanding it.
488
+ self._size_update(group, scale_grads, p, state)
489
+
490
+ if numel == 1:
491
+ # For parameters with 1 element we just use regular Adam.
492
+ # Updates delta.
493
+ self._step_scalar(group, p, state)
494
+ else:
495
+ self._step(group, p, state)
496
+
497
+ state["step"] = step + 1
498
+
499
+ def _size_update(self,
500
+ group: dict,
501
+ scale_grads: Tensor,
502
+ p: Tensor,
503
+ state: dict) -> None:
504
+ """
505
+ Called only where p.numel() > 1, this updates the scale of the parameter.
506
+ If we imagine: p = underlying_param * scale.exp(), and we are doing
507
+ gradient descent on underlying param and on scale, this function does the update
508
+ on `scale`.
509
+
510
+ Args:
511
+ group: dict to look up configuration values
512
+ scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
513
+ grads w.r.t. the scales.
514
+ p: The parameter to update
515
+ state: The state-dict of p
516
+ """
517
+
518
+ param_rms = state["param_rms"]
519
+ beta1, beta2 = group["betas"]
520
+ size_lr = group["lr"] * group["scalar_lr_scale"]
521
+ param_min_rms = group["param_min_rms"]
522
+ param_max_rms = group["param_max_rms"]
523
+ eps = group["eps"]
524
+ step = state["step"]
525
+ batch_size = p.shape[0]
526
+
527
+ size_update_period = scale_grads.shape[0]
528
+ # correct beta2 for the size update period: we will have
529
+ # faster decay at this level.
530
+ beta2_corr = beta2**size_update_period
531
+
532
+ scale_exp_avg_sq = state[
533
+ "scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
534
+ scale_exp_avg_sq.mul_(beta2_corr).add_(
535
+ (scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
536
+ alpha=1 - beta2_corr, ) # shape is (batch_size, 1, 1, ...)
537
+
538
+ # The 1st time we reach here is when size_step == 1.
539
+ size_step = (step + 1) // size_update_period
540
+ bias_correction2 = 1 - beta2_corr**size_step
541
+ # we don't bother with bias_correction1; this will help prevent divergence
542
+ # at the start of training.
543
+
544
+ denom = scale_exp_avg_sq.sqrt() + eps
545
+
546
+ scale_step = (-size_lr * (bias_correction2**0.5) *
547
+ scale_grads.sum(dim=0) / denom)
548
+
549
+ is_too_small = param_rms < param_min_rms
550
+ is_too_large = param_rms > param_max_rms
551
+
552
+ # when the param gets too small, just don't shrink it any further.
553
+ scale_step.masked_fill_(is_too_small, 0.0)
554
+ # when it gets too large, stop it from getting any larger.
555
+ scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
556
+ delta = state["delta"]
557
+ # the factor of (1-beta1) relates to momentum.
558
+ delta.add_(p * scale_step, alpha=(1 - beta1))
559
+
560
+ def _step(self, group: dict, p: Tensor, state: dict):
561
+ """
562
+ This function does the core update of self.step(), in the case where the members of
563
+ the batch have more than 1 element.
564
+
565
+ Args:
566
+ group: A dict which will be used to look up configuration values
567
+ p: The parameter to be updated
568
+ grad: The grad of p
569
+ state: The state-dict corresponding to parameter p
570
+
571
+ This function modifies p.
572
+ """
573
+ grad = p.grad
574
+ lr = group["lr"]
575
+ beta1, beta2 = group["betas"]
576
+ eps = group["eps"]
577
+ param_min_rms = group["param_min_rms"]
578
+ step = state["step"]
579
+
580
+ exp_avg_sq = state["exp_avg_sq"]
581
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
582
+
583
+ this_step = state["step"] - (state["zero_step"]
584
+ if "zero_step" in state else 0)
585
+ bias_correction2 = 1 - beta2**(this_step + 1)
586
+ if bias_correction2 < 0.99:
587
+ # note: not in-place.
588
+ exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
589
+
590
+ denom = exp_avg_sq.sqrt()
591
+ denom += eps
592
+ grad = grad / denom
593
+
594
+ alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
595
+
596
+ delta = state["delta"]
597
+ delta.add_(grad * alpha)
598
+ p.add_(delta)
599
+
600
+ def _step_scalar(self, group: dict, p: Tensor, state: dict):
601
+ """
602
+ A simplified form of the core update for scalar tensors, where we cannot get a good
603
+ estimate of the parameter rms.
604
+ """
605
+ beta1, beta2 = group["betas"]
606
+ scalar_max = group["scalar_max"]
607
+ eps = group["eps"]
608
+ lr = group["lr"] * group["scalar_lr_scale"]
609
+ grad = p.grad
610
+
611
+ exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
612
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
613
+
614
+ # bias_correction2 is like in Adam. Don't bother with bias_correction1;
615
+ # slower update at the start will help stability anyway.
616
+ bias_correction2 = 1 - beta2**(state["step"] + 1)
617
+ denom = (exp_avg_sq / bias_correction2).sqrt() + eps
618
+
619
+ delta = state["delta"]
620
+ delta.add_(grad / denom, alpha=-lr * (1 - beta1))
621
+ p.clamp_(min=-scalar_max, max=scalar_max)
622
+ p.add_(delta)
AR/modules/patched_mha_with_cache.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.nn.functional import *
2
+ from torch.nn.functional import (
3
+ _mha_shape_check,
4
+ _canonical_mask,
5
+ _none_or_dtype,
6
+ _in_projection_packed,
7
+ )
8
+
9
+ # import torch
10
+ # Tensor = torch.Tensor
11
+ # from typing import Callable, List, Optional, Tuple, Union
12
+
13
+
14
+ def multi_head_attention_forward_patched(
15
+ query: Tensor,
16
+ key: Tensor,
17
+ value: Tensor,
18
+ embed_dim_to_check: int,
19
+ num_heads: int,
20
+ in_proj_weight: Optional[Tensor],
21
+ in_proj_bias: Optional[Tensor],
22
+ bias_k: Optional[Tensor],
23
+ bias_v: Optional[Tensor],
24
+ add_zero_attn: bool,
25
+ dropout_p: float,
26
+ out_proj_weight: Tensor,
27
+ out_proj_bias: Optional[Tensor],
28
+ training: bool = True,
29
+ key_padding_mask: Optional[Tensor] = None,
30
+ need_weights: bool = True,
31
+ attn_mask: Optional[Tensor] = None,
32
+ use_separate_proj_weight: bool = False,
33
+ q_proj_weight: Optional[Tensor] = None,
34
+ k_proj_weight: Optional[Tensor] = None,
35
+ v_proj_weight: Optional[Tensor] = None,
36
+ static_k: Optional[Tensor] = None,
37
+ static_v: Optional[Tensor] = None,
38
+ average_attn_weights: bool = True,
39
+ is_causal: bool = False,
40
+ cache=None,
41
+ ) -> Tuple[Tensor, Optional[Tensor]]:
42
+ r"""
43
+ Args:
44
+ query, key, value: map a query and a set of key-value pairs to an output.
45
+ See "Attention Is All You Need" for more details.
46
+ embed_dim_to_check: total dimension of the model.
47
+ num_heads: parallel attention heads.
48
+ in_proj_weight, in_proj_bias: input projection weight and bias.
49
+ bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
50
+ add_zero_attn: add a new batch of zeros to the key and
51
+ value sequences at dim=1.
52
+ dropout_p: probability of an element to be zeroed.
53
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
54
+ training: apply dropout if is ``True``.
55
+ key_padding_mask: if provided, specified padding elements in the key will
56
+ be ignored by the attention. This is an binary mask. When the value is True,
57
+ the corresponding value on the attention layer will be filled with -inf.
58
+ need_weights: output attn_output_weights.
59
+ Default: `True`
60
+ Note: `needs_weight` defaults to `True`, but should be set to `False`
61
+ For best performance when attention weights are not nedeeded.
62
+ *Setting needs_weights to `True`
63
+ leads to a significant performance degradation.*
64
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
65
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
66
+ is_causal: If specified, applies a causal mask as attention mask, and ignores
67
+ attn_mask for computing scaled dot product attention.
68
+ Default: ``False``.
69
+ .. warning::
70
+ is_causal is provides a hint that the attn_mask is the
71
+ causal mask.Providing incorrect hints can result in
72
+ incorrect execution, including forward and backward
73
+ compatibility.
74
+ use_separate_proj_weight: the function accept the proj. weights for query, key,
75
+ and value in different forms. If false, in_proj_weight will be used, which is
76
+ a combination of q_proj_weight, k_proj_weight, v_proj_weight.
77
+ q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
78
+ static_k, static_v: static key and value used for attention operators.
79
+ average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
80
+ Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
81
+ when ``need_weights=True.``. Default: True
82
+
83
+
84
+ Shape:
85
+ Inputs:
86
+ - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
87
+ the embedding dimension.
88
+ - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
89
+ the embedding dimension.
90
+ - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
91
+ the embedding dimension.
92
+ - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
93
+ If a FloatTensor is provided, it will be directly added to the value.
94
+ If a BoolTensor is provided, the positions with the
95
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
96
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
97
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
98
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
99
+ positions. If a BoolTensor is provided, positions with ``True``
100
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
101
+ is provided, it will be added to the attention weight.
102
+ - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
103
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
104
+ - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
105
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
106
+
107
+ Outputs:
108
+ - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
109
+ E is the embedding dimension.
110
+ - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
111
+ attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
112
+ :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
113
+ :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
114
+ head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
115
+ """
116
+ tens_ops = (
117
+ query,
118
+ key,
119
+ value,
120
+ in_proj_weight,
121
+ in_proj_bias,
122
+ bias_k,
123
+ bias_v,
124
+ out_proj_weight,
125
+ out_proj_bias,
126
+ )
127
+ if has_torch_function(tens_ops):
128
+ return handle_torch_function(
129
+ multi_head_attention_forward,
130
+ tens_ops,
131
+ query,
132
+ key,
133
+ value,
134
+ embed_dim_to_check,
135
+ num_heads,
136
+ in_proj_weight,
137
+ in_proj_bias,
138
+ bias_k,
139
+ bias_v,
140
+ add_zero_attn,
141
+ dropout_p,
142
+ out_proj_weight,
143
+ out_proj_bias,
144
+ training=training,
145
+ key_padding_mask=key_padding_mask,
146
+ need_weights=need_weights,
147
+ attn_mask=attn_mask,
148
+ is_causal=is_causal,
149
+ use_separate_proj_weight=use_separate_proj_weight,
150
+ q_proj_weight=q_proj_weight,
151
+ k_proj_weight=k_proj_weight,
152
+ v_proj_weight=v_proj_weight,
153
+ static_k=static_k,
154
+ static_v=static_v,
155
+ average_attn_weights=average_attn_weights,
156
+ cache=cache,
157
+ )
158
+
159
+ is_batched = _mha_shape_check(
160
+ query, key, value, key_padding_mask, attn_mask, num_heads
161
+ )
162
+
163
+ # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
164
+ # is batched, run the computation and before returning squeeze the
165
+ # batch dimension so that the output doesn't carry this temporary batch dimension.
166
+ if not is_batched:
167
+ # unsqueeze if the input is unbatched
168
+ query = query.unsqueeze(1)
169
+ key = key.unsqueeze(1)
170
+ value = value.unsqueeze(1)
171
+ if key_padding_mask is not None:
172
+ key_padding_mask = key_padding_mask.unsqueeze(0)
173
+
174
+ # set up shape vars
175
+ tgt_len, bsz, embed_dim = query.shape
176
+ src_len, _, _ = key.shape
177
+
178
+ key_padding_mask = _canonical_mask(
179
+ mask=key_padding_mask,
180
+ mask_name="key_padding_mask",
181
+ other_type=_none_or_dtype(attn_mask),
182
+ other_name="attn_mask",
183
+ target_type=query.dtype,
184
+ )
185
+
186
+ if is_causal and attn_mask is None:
187
+ raise RuntimeError(
188
+ "Need attn_mask if specifying the is_causal hint. "
189
+ "You may use the Transformer module method "
190
+ "`generate_square_subsequent_mask` to create this mask."
191
+ )
192
+
193
+ if is_causal and key_padding_mask is None and not need_weights:
194
+ # when we have a kpm or need weights, we need attn_mask
195
+ # Otherwise, we use the is_causal hint go as is_causal
196
+ # indicator to SDPA.
197
+ attn_mask = None
198
+ else:
199
+ attn_mask = _canonical_mask(
200
+ mask=attn_mask,
201
+ mask_name="attn_mask",
202
+ other_type=None,
203
+ other_name="",
204
+ target_type=query.dtype,
205
+ check_other=False,
206
+ )
207
+
208
+ if key_padding_mask is not None:
209
+ # We have the attn_mask, and use that to merge kpm into it.
210
+ # Turn off use of is_causal hint, as the merged mask is no
211
+ # longer causal.
212
+ is_causal = False
213
+
214
+ assert (
215
+ embed_dim == embed_dim_to_check
216
+ ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
217
+ if isinstance(embed_dim, torch.Tensor):
218
+ # embed_dim can be a tensor when JIT tracing
219
+ head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
220
+ else:
221
+ head_dim = embed_dim // num_heads
222
+ assert (
223
+ head_dim * num_heads == embed_dim
224
+ ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
225
+ if use_separate_proj_weight:
226
+ # allow MHA to have different embedding dimensions when separate projection weights are used
227
+ assert (
228
+ key.shape[:2] == value.shape[:2]
229
+ ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
230
+ else:
231
+ assert (
232
+ key.shape == value.shape
233
+ ), f"key shape {key.shape} does not match value shape {value.shape}"
234
+
235
+ #
236
+ # compute in-projection
237
+ #
238
+ if not use_separate_proj_weight:
239
+ assert (
240
+ in_proj_weight is not None
241
+ ), "use_separate_proj_weight is False but in_proj_weight is None"
242
+ q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
243
+ else:
244
+ assert (
245
+ q_proj_weight is not None
246
+ ), "use_separate_proj_weight is True but q_proj_weight is None"
247
+ assert (
248
+ k_proj_weight is not None
249
+ ), "use_separate_proj_weight is True but k_proj_weight is None"
250
+ assert (
251
+ v_proj_weight is not None
252
+ ), "use_separate_proj_weight is True but v_proj_weight is None"
253
+ if in_proj_bias is None:
254
+ b_q = b_k = b_v = None
255
+ else:
256
+ b_q, b_k, b_v = in_proj_bias.chunk(3)
257
+ q, k, v = _in_projection(
258
+ query,
259
+ key,
260
+ value,
261
+ q_proj_weight,
262
+ k_proj_weight,
263
+ v_proj_weight,
264
+ b_q,
265
+ b_k,
266
+ b_v,
267
+ )
268
+ if cache != None:
269
+ if cache["first_infer"] == 1:
270
+ cache["k"][cache["stage"]] = k
271
+ # print(0,cache["k"].shape)
272
+ cache["v"][cache["stage"]] = v
273
+ else: ###12个layer每个都要留自己的cache_kv
274
+ # print(1,cache["k"].shape)
275
+ cache["k"][cache["stage"]] = torch.cat(
276
+ [cache["k"][cache["stage"]], k], 0
277
+ ) ##本来时序是1,但是proj的时候可能transpose了所以时序到0维了
278
+ cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]], v], 0)
279
+ # print(2, cache["k"].shape)
280
+ src_len = cache["k"][cache["stage"]].shape[0]
281
+ k = cache["k"][cache["stage"]]
282
+ v = cache["v"][cache["stage"]]
283
+ # if attn_mask is not None:
284
+ # attn_mask=attn_mask[-1:,]
285
+ # print(attn_mask.shape,attn_mask)
286
+ cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
287
+ # print(2333,cache)
288
+ # prep attention mask
289
+
290
+ attn_mask = _canonical_mask(
291
+ mask=attn_mask,
292
+ mask_name="attn_mask",
293
+ other_type=None,
294
+ other_name="",
295
+ target_type=q.dtype,
296
+ check_other=False,
297
+ )
298
+
299
+ if attn_mask is not None:
300
+ # ensure attn_mask's dim is 3
301
+ if attn_mask.dim() == 2:
302
+ correct_2d_size = (tgt_len, src_len)
303
+ if attn_mask.shape != correct_2d_size:
304
+ raise RuntimeError(
305
+ f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
306
+ )
307
+ attn_mask = attn_mask.unsqueeze(0)
308
+ elif attn_mask.dim() == 3:
309
+ correct_3d_size = (bsz * num_heads, tgt_len, src_len)
310
+ if attn_mask.shape != correct_3d_size:
311
+ raise RuntimeError(
312
+ f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
313
+ )
314
+ else:
315
+ raise RuntimeError(
316
+ f"attn_mask's dimension {attn_mask.dim()} is not supported"
317
+ )
318
+
319
+ # add bias along batch dimension (currently second)
320
+ if bias_k is not None and bias_v is not None:
321
+ assert static_k is None, "bias cannot be added to static key."
322
+ assert static_v is None, "bias cannot be added to static value."
323
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
324
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
325
+ if attn_mask is not None:
326
+ attn_mask = pad(attn_mask, (0, 1))
327
+ if key_padding_mask is not None:
328
+ key_padding_mask = pad(key_padding_mask, (0, 1))
329
+ else:
330
+ assert bias_k is None
331
+ assert bias_v is None
332
+
333
+ #
334
+ # reshape q, k, v for multihead attention and make em batch first
335
+ #
336
+ q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
337
+ if static_k is None:
338
+ k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
339
+ else:
340
+ # TODO finish disentangling control flow so we don't do in-projections when statics are passed
341
+ assert (
342
+ static_k.size(0) == bsz * num_heads
343
+ ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
344
+ assert (
345
+ static_k.size(2) == head_dim
346
+ ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
347
+ k = static_k
348
+ if static_v is None:
349
+ v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
350
+ else:
351
+ # TODO finish disentangling control flow so we don't do in-projections when statics are passed
352
+ assert (
353
+ static_v.size(0) == bsz * num_heads
354
+ ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
355
+ assert (
356
+ static_v.size(2) == head_dim
357
+ ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
358
+ v = static_v
359
+
360
+ # add zero attention along batch dimension (now first)
361
+ if add_zero_attn:
362
+ zero_attn_shape = (bsz * num_heads, 1, head_dim)
363
+ k = torch.cat(
364
+ [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
365
+ )
366
+ v = torch.cat(
367
+ [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
368
+ )
369
+ if attn_mask is not None:
370
+ attn_mask = pad(attn_mask, (0, 1))
371
+ if key_padding_mask is not None:
372
+ key_padding_mask = pad(key_padding_mask, (0, 1))
373
+
374
+ # update source sequence length after adjustments
375
+ src_len = k.size(1)
376
+
377
+ # merge key padding and attention masks
378
+ if key_padding_mask is not None:
379
+ assert key_padding_mask.shape == (
380
+ bsz,
381
+ src_len,
382
+ ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
383
+ key_padding_mask = (
384
+ key_padding_mask.view(bsz, 1, 1, src_len)
385
+ .expand(-1, num_heads, -1, -1)
386
+ .reshape(bsz * num_heads, 1, src_len)
387
+ )
388
+ if attn_mask is None:
389
+ attn_mask = key_padding_mask
390
+ else:
391
+ attn_mask = attn_mask + key_padding_mask
392
+
393
+ # adjust dropout probability
394
+ if not training:
395
+ dropout_p = 0.0
396
+
397
+ #
398
+ # (deep breath) calculate attention and out projection
399
+ #
400
+
401
+ if need_weights:
402
+ B, Nt, E = q.shape
403
+ q_scaled = q / math.sqrt(E)
404
+
405
+ assert not (
406
+ is_causal and attn_mask is None
407
+ ), "FIXME: is_causal not implemented for need_weights"
408
+
409
+ if attn_mask is not None:
410
+ attn_output_weights = torch.baddbmm(
411
+ attn_mask, q_scaled, k.transpose(-2, -1)
412
+ )
413
+ else:
414
+ attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
415
+ attn_output_weights = softmax(attn_output_weights, dim=-1)
416
+ if dropout_p > 0.0:
417
+ attn_output_weights = dropout(attn_output_weights, p=dropout_p)
418
+
419
+ attn_output = torch.bmm(attn_output_weights, v)
420
+
421
+ attn_output = (
422
+ attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
423
+ )
424
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
425
+ attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
426
+
427
+ # optionally average attention weights over heads
428
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
429
+ if average_attn_weights:
430
+ attn_output_weights = attn_output_weights.mean(dim=1)
431
+
432
+ if not is_batched:
433
+ # squeeze the output if input was unbatched
434
+ attn_output = attn_output.squeeze(1)
435
+ attn_output_weights = attn_output_weights.squeeze(0)
436
+ return attn_output, attn_output_weights
437
+ else:
438
+ # attn_mask can be either (L,S) or (N*num_heads, L, S)
439
+ # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
440
+ # in order to match the input for SDPA of (N, num_heads, L, S)
441
+ if attn_mask is not None:
442
+ if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
443
+ attn_mask = attn_mask.unsqueeze(0)
444
+ else:
445
+ attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
446
+
447
+ q = q.view(bsz, num_heads, tgt_len, head_dim)
448
+ k = k.view(bsz, num_heads, src_len, head_dim)
449
+ v = v.view(bsz, num_heads, src_len, head_dim)
450
+
451
+ attn_output = scaled_dot_product_attention(
452
+ q, k, v, attn_mask, dropout_p, is_causal
453
+ )
454
+ attn_output = (
455
+ attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
456
+ )
457
+
458
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
459
+ attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
460
+ if not is_batched:
461
+ # squeeze the output if input was unbatched
462
+ attn_output = attn_output.squeeze(1)
463
+ return attn_output, None
AR/modules/scaling.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import logging
17
+ import math
18
+ import random
19
+ from typing import Optional
20
+ from typing import Tuple
21
+ from typing import Union
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ from torch import Tensor
26
+
27
+
28
+ class DoubleSwishFunction(torch.autograd.Function):
29
+ """
30
+ double_swish(x) = x * torch.sigmoid(x-1)
31
+ This is a definition, originally motivated by its close numerical
32
+ similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
33
+
34
+ Memory-efficient derivative computation:
35
+ double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
36
+ double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
37
+ Now, s'(x) = s(x) * (1-s(x)).
38
+ double_swish'(x) = x * s'(x) + s(x).
39
+ = x * s(x) * (1-s(x)) + s(x).
40
+ = double_swish(x) * (1-s(x)) + s(x)
41
+ ... so we just need to remember s(x) but not x itself.
42
+ """
43
+
44
+ @staticmethod
45
+ def forward(ctx, x: Tensor) -> Tensor:
46
+ requires_grad = x.requires_grad
47
+ x_dtype = x.dtype
48
+ if x.dtype == torch.float16:
49
+ x = x.to(torch.float32)
50
+
51
+ s = torch.sigmoid(x - 1.0)
52
+ y = x * s
53
+
54
+ if requires_grad:
55
+ deriv = y * (1 - s) + s
56
+ # notes on derivative of x * sigmoid(x - 1):
57
+ # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
58
+ # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund
59
+ # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
60
+ # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
61
+ # floors), should be expectation-preserving.
62
+ floor = -0.043637
63
+ ceil = 1.2
64
+ d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
65
+ deriv
66
+ )
67
+ if __name__ == "__main__":
68
+ # for self-testing only.
69
+ assert d_scaled.min() >= 0.0
70
+ assert d_scaled.max() < 256.0
71
+ d_int = d_scaled.to(torch.uint8)
72
+ ctx.save_for_backward(d_int)
73
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
74
+ y = y.to(torch.float16)
75
+ return y
76
+
77
+ @staticmethod
78
+ def backward(ctx, y_grad: Tensor) -> Tensor:
79
+ (d,) = ctx.saved_tensors
80
+ # the same constants as used in forward pass.
81
+ floor = -0.043637
82
+ ceil = 1.2
83
+ d = d * ((ceil - floor) / 255.0) + floor
84
+ return y_grad * d
85
+
86
+
87
+ class DoubleSwish(torch.nn.Module):
88
+ def forward(self, x: Tensor) -> Tensor:
89
+ """Return double-swish activation function which is an approximation to Swish(Swish(x)),
90
+ that we approximate closely with x * sigmoid(x-1).
91
+ """
92
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
93
+ return x * torch.sigmoid(x - 1.0)
94
+ return DoubleSwishFunction.apply(x)
95
+
96
+
97
+ class ActivationBalancerFunction(torch.autograd.Function):
98
+ @staticmethod
99
+ def forward(
100
+ ctx,
101
+ x: Tensor,
102
+ scale_factor: Tensor,
103
+ sign_factor: Optional[Tensor],
104
+ channel_dim: int,
105
+ ) -> Tensor:
106
+ if channel_dim < 0:
107
+ channel_dim += x.ndim
108
+ ctx.channel_dim = channel_dim
109
+ xgt0 = x > 0
110
+ if sign_factor is None:
111
+ ctx.save_for_backward(xgt0, scale_factor)
112
+ else:
113
+ ctx.save_for_backward(xgt0, scale_factor, sign_factor)
114
+ return x
115
+
116
+ @staticmethod
117
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
118
+ if len(ctx.saved_tensors) == 3:
119
+ xgt0, scale_factor, sign_factor = ctx.saved_tensors
120
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
121
+ scale_factor = scale_factor.unsqueeze(-1)
122
+ sign_factor = sign_factor.unsqueeze(-1)
123
+ factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
124
+ else:
125
+ xgt0, scale_factor = ctx.saved_tensors
126
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
127
+ scale_factor = scale_factor.unsqueeze(-1)
128
+ factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
129
+ neg_delta_grad = x_grad.abs() * factor
130
+ return (
131
+ x_grad - neg_delta_grad,
132
+ None,
133
+ None,
134
+ None,
135
+ )
136
+
137
+
138
+ def _compute_scale_factor(
139
+ x: Tensor,
140
+ channel_dim: int,
141
+ min_abs: float,
142
+ max_abs: float,
143
+ gain_factor: float,
144
+ max_factor: float,
145
+ ) -> Tensor:
146
+ if channel_dim < 0:
147
+ channel_dim += x.ndim
148
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
149
+ x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
150
+
151
+ if min_abs == 0.0:
152
+ below_threshold = 0.0
153
+ else:
154
+ # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
155
+ # x_abs)_mean , min_abs.
156
+ below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
157
+ min=0, max=max_factor
158
+ )
159
+
160
+ above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
161
+ min=0, max=max_factor
162
+ )
163
+
164
+ return below_threshold - above_threshold
165
+
166
+
167
+ def _compute_sign_factor(
168
+ x: Tensor,
169
+ channel_dim: int,
170
+ min_positive: float,
171
+ max_positive: float,
172
+ gain_factor: float,
173
+ max_factor: float,
174
+ ) -> Tensor:
175
+ if channel_dim < 0:
176
+ channel_dim += x.ndim
177
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
178
+ proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
179
+ if min_positive == 0.0:
180
+ factor1 = 0.0
181
+ else:
182
+ # 0 if proportion_positive >= min_positive, else can be
183
+ # as large as max_factor.
184
+ factor1 = (
185
+ (min_positive - proportion_positive) * (gain_factor / min_positive)
186
+ ).clamp_(min=0, max=max_factor)
187
+
188
+ if max_positive == 1.0:
189
+ factor2 = 0.0
190
+ else:
191
+ # 0 if self.proportion_positive <= max_positive, else can be
192
+ # as large as -max_factor.
193
+ factor2 = (
194
+ (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
195
+ ).clamp_(min=0, max=max_factor)
196
+ sign_factor = factor1 - factor2
197
+ # require min_positive != 0 or max_positive != 1:
198
+ assert not isinstance(sign_factor, float)
199
+ return sign_factor
200
+
201
+
202
+ class ActivationBalancer(torch.nn.Module):
203
+ """
204
+ Modifies the backpropped derivatives of a function to try to encourage, for
205
+ each channel, that it is positive at least a proportion `threshold` of the
206
+ time. It does this by multiplying negative derivative values by up to
207
+ (1+max_factor), and positive derivative values by up to (1-max_factor),
208
+ interpolated from 1 at the threshold to those extremal values when none
209
+ of the inputs are positive.
210
+
211
+ Args:
212
+ num_channels: the number of channels
213
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
214
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
215
+ min_positive: the minimum, per channel, of the proportion of the time
216
+ that (x > 0), below which we start to modify the derivatives.
217
+ max_positive: the maximum, per channel, of the proportion of the time
218
+ that (x > 0), above which we start to modify the derivatives.
219
+ max_factor: the maximum factor by which we modify the derivatives for
220
+ either the sign constraint or the magnitude constraint;
221
+ e.g. with max_factor=0.02, the the derivatives would be multiplied by
222
+ values in the range [0.98..1.02].
223
+ sign_gain_factor: determines the 'gain' with which we increase the
224
+ change in gradient once the constraints on min_positive and max_positive
225
+ are violated.
226
+ scale_gain_factor: determines the 'gain' with which we increase the
227
+ change in gradient once the constraints on min_abs and max_abs
228
+ are violated.
229
+ min_abs: the minimum average-absolute-value difference from the mean
230
+ value per channel, which we allow, before we start to modify
231
+ the derivatives to prevent this.
232
+ max_abs: the maximum average-absolute-value difference from the mean
233
+ value per channel, which we allow, before we start to modify
234
+ the derivatives to prevent this.
235
+ min_prob: determines the minimum probability with which we modify the
236
+ gradients for the {min,max}_positive and {min,max}_abs constraints,
237
+ on each forward(). This is done randomly to prevent all layers
238
+ from doing it at the same time. Early in training we may use
239
+ higher probabilities than this; it will decay to this value.
240
+ """
241
+
242
+ def __init__(
243
+ self,
244
+ num_channels: int,
245
+ channel_dim: int,
246
+ min_positive: float = 0.05,
247
+ max_positive: float = 0.95,
248
+ max_factor: float = 0.04,
249
+ sign_gain_factor: float = 0.01,
250
+ scale_gain_factor: float = 0.02,
251
+ min_abs: float = 0.2,
252
+ max_abs: float = 100.0,
253
+ min_prob: float = 0.1,
254
+ ):
255
+ super(ActivationBalancer, self).__init__()
256
+ self.num_channels = num_channels
257
+ self.channel_dim = channel_dim
258
+ self.min_positive = min_positive
259
+ self.max_positive = max_positive
260
+ self.max_factor = max_factor
261
+ self.min_abs = min_abs
262
+ self.max_abs = max_abs
263
+ self.min_prob = min_prob
264
+ self.sign_gain_factor = sign_gain_factor
265
+ self.scale_gain_factor = scale_gain_factor
266
+
267
+ # count measures how many times the forward() function has been called.
268
+ # We occasionally sync this to a tensor called `count`, that exists to
269
+ # make sure it is synced to disk when we load and save the model.
270
+ self.cpu_count = 0
271
+ self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
272
+
273
+ def forward(self, x: Tensor) -> Tensor:
274
+ if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
275
+ return _no_op(x)
276
+
277
+ count = self.cpu_count
278
+ self.cpu_count += 1
279
+
280
+ if random.random() < 0.01:
281
+ # Occasionally sync self.cpu_count with self.count.
282
+ # count affects the decay of 'prob'. don't do this on every iter,
283
+ # because syncing with the GPU is slow.
284
+ self.cpu_count = max(self.cpu_count, self.count.item())
285
+ self.count.fill_(self.cpu_count)
286
+
287
+ # the prob of doing some work exponentially decreases from 0.5 till it hits
288
+ # a floor at min_prob (==0.1, by default)
289
+ prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
290
+
291
+ if random.random() < prob:
292
+ sign_gain_factor = 0.5
293
+ if self.min_positive != 0.0 or self.max_positive != 1.0:
294
+ sign_factor = _compute_sign_factor(
295
+ x,
296
+ self.channel_dim,
297
+ self.min_positive,
298
+ self.max_positive,
299
+ gain_factor=self.sign_gain_factor / prob,
300
+ max_factor=self.max_factor,
301
+ )
302
+ else:
303
+ sign_factor = None
304
+
305
+ scale_factor = _compute_scale_factor(
306
+ x.detach(),
307
+ self.channel_dim,
308
+ min_abs=self.min_abs,
309
+ max_abs=self.max_abs,
310
+ gain_factor=self.scale_gain_factor / prob,
311
+ max_factor=self.max_factor,
312
+ )
313
+ return ActivationBalancerFunction.apply(
314
+ x,
315
+ scale_factor,
316
+ sign_factor,
317
+ self.channel_dim,
318
+ )
319
+ else:
320
+ return _no_op(x)
321
+
322
+
323
+ def BalancedDoubleSwish(
324
+ d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
325
+ ) -> nn.Sequential:
326
+ """
327
+ ActivationBalancer -> DoubleSwish
328
+ """
329
+ balancer = ActivationBalancer(
330
+ d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
331
+ )
332
+ return nn.Sequential(
333
+ balancer,
334
+ DoubleSwish(),
335
+ )
AR/modules/transformer.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py
2
+ import copy
3
+ import numbers
4
+ from functools import partial
5
+ from typing import Any
6
+ from typing import Callable
7
+ from typing import List
8
+ from typing import Optional
9
+ from typing import Tuple
10
+ from typing import Union
11
+
12
+ import torch
13
+ from AR.modules.activation import MultiheadAttention
14
+ from AR.modules.scaling import BalancedDoubleSwish
15
+ from torch import nn
16
+ from torch import Tensor
17
+ from torch.nn import functional as F
18
+
19
+ _shape_t = Union[int, List[int], torch.Size]
20
+
21
+
22
+ class LayerNorm(nn.Module):
23
+ __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
24
+ normalized_shape: Tuple[int, ...]
25
+ eps: float
26
+ elementwise_affine: bool
27
+
28
+ def __init__(
29
+ self,
30
+ normalized_shape: _shape_t,
31
+ eps: float = 1e-5,
32
+ elementwise_affine: bool = True,
33
+ device=None,
34
+ dtype=None,
35
+ ) -> None:
36
+ factory_kwargs = {"device": device, "dtype": dtype}
37
+ super(LayerNorm, self).__init__()
38
+ if isinstance(normalized_shape, numbers.Integral):
39
+ # mypy error: incompatible types in assignment
40
+ normalized_shape = (normalized_shape,) # type: ignore[assignment]
41
+ self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
42
+ self.eps = eps
43
+ self.elementwise_affine = elementwise_affine
44
+ if self.elementwise_affine:
45
+ self.weight = nn.Parameter(
46
+ torch.empty(self.normalized_shape, **factory_kwargs)
47
+ )
48
+ self.bias = nn.Parameter(
49
+ torch.empty(self.normalized_shape, **factory_kwargs)
50
+ )
51
+ else:
52
+ self.register_parameter("weight", None)
53
+ self.register_parameter("bias", None)
54
+
55
+ self.reset_parameters()
56
+
57
+ def reset_parameters(self) -> None:
58
+ if self.elementwise_affine:
59
+ nn.init.ones_(self.weight)
60
+ nn.init.zeros_(self.bias)
61
+
62
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
63
+ if isinstance(input, tuple):
64
+ input, embedding = input
65
+ return (
66
+ F.layer_norm(
67
+ input,
68
+ self.normalized_shape,
69
+ self.weight,
70
+ self.bias,
71
+ self.eps,
72
+ ),
73
+ embedding,
74
+ )
75
+
76
+ assert embedding is None
77
+ return F.layer_norm(
78
+ input, self.normalized_shape, self.weight, self.bias, self.eps
79
+ )
80
+
81
+ def extra_repr(self) -> str:
82
+ return (
83
+ "{normalized_shape}, eps={eps}, "
84
+ "elementwise_affine={elementwise_affine}".format(**self.__dict__)
85
+ )
86
+
87
+
88
+ class IdentityNorm(nn.Module):
89
+ def __init__(
90
+ self,
91
+ d_model: int,
92
+ eps: float = 1e-5,
93
+ device=None,
94
+ dtype=None,
95
+ ) -> None:
96
+ super(IdentityNorm, self).__init__()
97
+
98
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
99
+ if isinstance(input, tuple):
100
+ return input
101
+
102
+ assert embedding is None
103
+ return input
104
+
105
+
106
+ class TransformerEncoder(nn.Module):
107
+ r"""TransformerEncoder is a stack of N encoder layers. Users can build the
108
+ BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
109
+
110
+ Args:
111
+ encoder_layer: an instance of the TransformerEncoderLayer() class (required).
112
+ num_layers: the number of sub-encoder-layers in the encoder (required).
113
+ norm: the layer normalization component (optional).
114
+ enable_nested_tensor: if True, input will automatically convert to nested tensor
115
+ (and convert back on output). This will improve the overall performance of
116
+ TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
117
+
118
+ Examples::
119
+ >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
120
+ >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
121
+ >>> src = torch.rand(10, 32, 512)
122
+ >>> out = transformer_encoder(src)
123
+ """
124
+ __constants__ = ["norm"]
125
+
126
+ def __init__(self, encoder_layer, num_layers, norm=None):
127
+ super(TransformerEncoder, self).__init__()
128
+ self.layers = _get_clones(encoder_layer, num_layers)
129
+ self.num_layers = num_layers
130
+ self.norm = norm
131
+
132
+ def forward(
133
+ self,
134
+ src: Tensor,
135
+ mask: Optional[Tensor] = None,
136
+ src_key_padding_mask: Optional[Tensor] = None,
137
+ return_layer_states: bool = False,
138
+ cache=None,
139
+ ) -> Tensor:
140
+ r"""Pass the input through the encoder layers in turn.
141
+
142
+ Args:
143
+ src: the sequence to the encoder (required).
144
+ mask: the mask for the src sequence (optional).
145
+ src_key_padding_mask: the mask for the src keys per batch (optional).
146
+ return_layer_states: return layers' state (optional).
147
+
148
+ Shape:
149
+ see the docs in Transformer class.
150
+ """
151
+ if return_layer_states:
152
+ layer_states = [] # layers' output
153
+ output = src
154
+ for mod in self.layers:
155
+ output = mod(
156
+ output,
157
+ src_mask=mask,
158
+ src_key_padding_mask=src_key_padding_mask,
159
+ cache=cache,
160
+ )
161
+ layer_states.append(output[0])
162
+
163
+ if self.norm is not None:
164
+ output = self.norm(output)
165
+
166
+ return layer_states, output
167
+
168
+ output = src
169
+ for mod in self.layers:
170
+ output = mod(
171
+ output,
172
+ src_mask=mask,
173
+ src_key_padding_mask=src_key_padding_mask,
174
+ cache=cache,
175
+ )
176
+
177
+ if self.norm is not None:
178
+ output = self.norm(output)
179
+
180
+ return output
181
+
182
+
183
+ class TransformerEncoderLayer(nn.Module):
184
+ __constants__ = ["batch_first", "norm_first"]
185
+
186
+ def __init__(
187
+ self,
188
+ d_model: int,
189
+ nhead: int,
190
+ dim_feedforward: int = 2048,
191
+ dropout: float = 0.1,
192
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
193
+ batch_first: bool = False,
194
+ norm_first: bool = False,
195
+ device=None,
196
+ dtype=None,
197
+ linear1_self_attention_cls: nn.Module = nn.Linear,
198
+ linear2_self_attention_cls: nn.Module = nn.Linear,
199
+ linear1_feedforward_cls: nn.Module = nn.Linear,
200
+ linear2_feedforward_cls: nn.Module = nn.Linear,
201
+ layer_norm_cls: nn.Module = LayerNorm,
202
+ layer_norm_eps: float = 1e-5,
203
+ adaptive_layer_norm=False,
204
+ ) -> None:
205
+ factory_kwargs = {"device": device, "dtype": dtype}
206
+ super(TransformerEncoderLayer, self).__init__()
207
+ # print(233333333333,d_model,nhead)
208
+ # import os
209
+ # os._exit(2333333)
210
+ self.self_attn = MultiheadAttention(
211
+ d_model, # 512 16
212
+ nhead,
213
+ dropout=dropout,
214
+ batch_first=batch_first,
215
+ linear1_cls=linear1_self_attention_cls,
216
+ linear2_cls=linear2_self_attention_cls,
217
+ **factory_kwargs,
218
+ )
219
+
220
+ # Implementation of Feedforward model
221
+ self.linear1 = linear1_feedforward_cls(
222
+ d_model, dim_feedforward, **factory_kwargs
223
+ )
224
+ self.dropout = nn.Dropout(dropout)
225
+ self.linear2 = linear2_feedforward_cls(
226
+ dim_feedforward, d_model, **factory_kwargs
227
+ )
228
+
229
+ self.norm_first = norm_first
230
+ self.dropout1 = nn.Dropout(dropout)
231
+ self.dropout2 = nn.Dropout(dropout)
232
+
233
+ # Legacy string support for activation function.
234
+ if isinstance(activation, str):
235
+ activation = _get_activation_fn(activation)
236
+ elif isinstance(activation, partial):
237
+ activation = activation(d_model)
238
+ elif activation == BalancedDoubleSwish:
239
+ activation = BalancedDoubleSwish(d_model)
240
+
241
+ # # We can't test self.activation in forward() in TorchScript,
242
+ # # so stash some information about it instead.
243
+ # if activation is F.relu or isinstance(activation, torch.nn.ReLU):
244
+ # self.activation_relu_or_gelu = 1
245
+ # elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
246
+ # self.activation_relu_or_gelu = 2
247
+ # else:
248
+ # self.activation_relu_or_gelu = 0
249
+ self.activation = activation
250
+
251
+ norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
252
+ if layer_norm_cls == IdentityNorm:
253
+ norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
254
+ else:
255
+ norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
256
+
257
+ if adaptive_layer_norm:
258
+ self.norm1 = AdaptiveLayerNorm(d_model, norm1)
259
+ self.norm2 = AdaptiveLayerNorm(d_model, norm2)
260
+ else:
261
+ self.norm1 = norm1
262
+ self.norm2 = norm2
263
+
264
+ def __setstate__(self, state):
265
+ super(TransformerEncoderLayer, self).__setstate__(state)
266
+ if not hasattr(self, "activation"):
267
+ self.activation = F.relu
268
+
269
+ def forward(
270
+ self,
271
+ src: Tensor,
272
+ src_mask: Optional[Tensor] = None,
273
+ src_key_padding_mask: Optional[Tensor] = None,
274
+ cache=None,
275
+ ) -> Tensor:
276
+ r"""Pass the input through the encoder layer.
277
+
278
+ Args:
279
+ src: the sequence to the encoder layer (required).
280
+ src_mask: the mask for the src sequence (optional).
281
+ src_key_padding_mask: the mask for the src keys per batch (optional).
282
+
283
+ Shape:
284
+ see the docs in Transformer class.
285
+ """
286
+ x, stage_embedding = src, None
287
+ is_src_tuple = False
288
+ if isinstance(src, tuple):
289
+ x, stage_embedding = src
290
+ is_src_tuple = True
291
+
292
+ if src_key_padding_mask is not None:
293
+ _skpm_dtype = src_key_padding_mask.dtype
294
+ if _skpm_dtype != torch.bool and not torch.is_floating_point(
295
+ src_key_padding_mask
296
+ ):
297
+ raise AssertionError(
298
+ "only bool and floating types of key_padding_mask are supported"
299
+ )
300
+
301
+ if self.norm_first:
302
+ x = x + self._sa_block(
303
+ self.norm1(x, stage_embedding),
304
+ src_mask,
305
+ src_key_padding_mask,
306
+ cache=cache,
307
+ )
308
+ x = x + self._ff_block(self.norm2(x, stage_embedding))
309
+ else:
310
+ x = self.norm1(
311
+ x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
312
+ stage_embedding,
313
+ )
314
+ x = self.norm2(x + self._ff_block(x), stage_embedding)
315
+
316
+ if is_src_tuple:
317
+ return (x, stage_embedding)
318
+ return x
319
+
320
+ # self-attention block
321
+ def _sa_block(
322
+ self,
323
+ x: Tensor,
324
+ attn_mask: Optional[Tensor],
325
+ key_padding_mask: Optional[Tensor],
326
+ cache=None,
327
+ ) -> Tensor:
328
+ # print(x.shape,attn_mask.shape,key_padding_mask)
329
+ # torch.Size([1, 188, 512]) torch.Size([188, 188]) None
330
+ # import os
331
+ # os._exit(23333)
332
+ x = self.self_attn(
333
+ x,
334
+ x,
335
+ x,
336
+ attn_mask=attn_mask,
337
+ key_padding_mask=key_padding_mask,
338
+ need_weights=False,
339
+ cache=cache,
340
+ )[0]
341
+ return self.dropout1(x)
342
+
343
+ # feed forward block
344
+ def _ff_block(self, x: Tensor) -> Tensor:
345
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
346
+ return self.dropout2(x)
347
+
348
+
349
+ class AdaptiveLayerNorm(nn.Module):
350
+ r"""Adaptive Layer Normalization"""
351
+
352
+ def __init__(self, d_model, norm) -> None:
353
+ super(AdaptiveLayerNorm, self).__init__()
354
+ self.project_layer = nn.Linear(d_model, 2 * d_model)
355
+ self.norm = norm
356
+ self.d_model = d_model
357
+ self.eps = self.norm.eps
358
+
359
+ def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
360
+ if isinstance(input, tuple):
361
+ input, embedding = input
362
+ weight, bias = torch.split(
363
+ self.project_layer(embedding),
364
+ split_size_or_sections=self.d_model,
365
+ dim=-1,
366
+ )
367
+ return (weight * self.norm(input) + bias, embedding)
368
+
369
+ weight, bias = torch.split(
370
+ self.project_layer(embedding),
371
+ split_size_or_sections=self.d_model,
372
+ dim=-1,
373
+ )
374
+ return weight * self.norm(input) + bias
375
+
376
+
377
+ def _get_clones(module, N):
378
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
AR/text_processing/__init__.py ADDED
File without changes
AR/text_processing/phonemizer.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/phonemizer.py
2
+ import itertools
3
+ import re
4
+ from typing import Dict
5
+ from typing import List
6
+
7
+ import regex
8
+ from gruut import sentences
9
+ from gruut.const import Sentence
10
+ from gruut.const import Word
11
+ from AR.text_processing.symbols import SYMBOL_TO_ID
12
+
13
+
14
+ class GruutPhonemizer:
15
+ def __init__(self, language: str):
16
+ self._phonemizer = sentences
17
+ self.lang = language
18
+ self.symbol_to_id = SYMBOL_TO_ID
19
+ self._special_cases_dict: Dict[str] = {
20
+ r"\.\.\.": "... ",
21
+ ";": "; ",
22
+ ":": ": ",
23
+ ",": ", ",
24
+ r"\.": ". ",
25
+ "!": "! ",
26
+ r"\?": "? ",
27
+ "—": "—",
28
+ "…": "… ",
29
+ "«": "«",
30
+ "»": "»",
31
+ }
32
+ self._punctuation_regexp: str = (
33
+ rf"([{''.join(self._special_cases_dict.keys())}])"
34
+ )
35
+
36
+ def _normalize_punctuation(self, text: str) -> str:
37
+ text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
38
+ text = regex.sub(rf"{self._punctuation_regexp}(\pL)", r"\1 \2", text)
39
+ text = regex.sub(r"\pZ+", r" ", text)
40
+ return text.strip()
41
+
42
+ def _convert_punctuation(self, word: Word) -> str:
43
+ if not word.phonemes:
44
+ return ""
45
+ if word.phonemes[0] in ["‖", "|"]:
46
+ return word.text.strip()
47
+
48
+ phonemes = "".join(word.phonemes)
49
+ # remove modifier characters ˈˌː with regex
50
+ phonemes = re.sub(r"[ˈˌː͡]", "", phonemes)
51
+ return phonemes.strip()
52
+
53
+ def phonemize(self, text: str, espeak: bool = False) -> str:
54
+ text_to_phonemize: str = self._normalize_punctuation(text)
55
+ sents: List[Sentence] = [
56
+ sent
57
+ for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)
58
+ ]
59
+ words: List[str] = [
60
+ self._convert_punctuation(word) for word in itertools.chain(*sents)
61
+ ]
62
+ return " ".join(words)
63
+
64
+ def transform(self, phonemes):
65
+ # convert phonemes to ids
66
+ # dictionary is in symbols.py
67
+ return [self.symbol_to_id[p] for p in phonemes if p in self.symbol_to_id.keys()]
68
+
69
+
70
+ if __name__ == "__main__":
71
+ phonemizer = GruutPhonemizer("en-us")
72
+ # text -> IPA
73
+ phonemes = phonemizer.phonemize("Hello, wor-ld ?")
74
+ print("phonemes:", phonemes)
75
+ print("len(phonemes):", len(phonemes))
76
+ phoneme_ids = phonemizer.transform(phonemes)
77
+ print("phoneme_ids:", phoneme_ids)
78
+ print("len(phoneme_ids):", len(phoneme_ids))
AR/text_processing/symbols.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/symbols.py
2
+ PAD = "_"
3
+ PUNCTUATION = ';:,.!?¡¿—…"«»“” '
4
+ LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
5
+ IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
6
+ SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
7
+ SPACE_ID = SYMBOLS.index(" ")
8
+ SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}
9
+ ID_TO_SYMBOL = {i: s for i, s in enumerate(SYMBOLS)}
AR/utils/__init__.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+
4
+ def str2bool(str):
5
+ return True if str.lower() == 'true' else False
6
+
7
+
8
+ def get_newest_ckpt(string_list):
9
+ # 定义一个正则表达式模式,用于匹配字符串中的数字
10
+ pattern = r'epoch=(\d+)-step=(\d+)\.ckpt'
11
+
12
+ # 使用正则表达式提取每个字符串中的数字信息,并创建一个包含元组的列表
13
+ extracted_info = []
14
+ for string in string_list:
15
+ match = re.match(pattern, string)
16
+ if match:
17
+ epoch = int(match.group(1))
18
+ step = int(match.group(2))
19
+ extracted_info.append((epoch, step, string))
20
+ # 按照 epoch 后面的数字和 step 后面的数字进行排序
21
+ sorted_info = sorted(
22
+ extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
23
+ # 获取最新的 ckpt 文件名
24
+ newest_ckpt = sorted_info[0][2]
25
+ return newest_ckpt
26
+
27
+
28
+ # 文本存在且不为空时 return True
29
+ def check_txt_file(file_path):
30
+ try:
31
+ with open(file_path, 'r') as file:
32
+ text = file.readline().strip()
33
+ assert text.strip() != ''
34
+ return text
35
+ except Exception:
36
+ return False
37
+ return False
AR/utils/initialize.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Initialize modules for espnet2 neural networks."""
3
+ import torch
4
+ from typeguard import check_argument_types
5
+
6
+
7
+ def initialize(model: torch.nn.Module, init: str):
8
+ """Initialize weights of a neural network module.
9
+
10
+ Parameters are initialized using the given method or distribution.
11
+
12
+ Custom initialization routines can be implemented into submodules
13
+ as function `espnet_initialization_fn` within the custom module.
14
+
15
+ Args:
16
+ model: Target.
17
+ init: Method of initialization.
18
+ """
19
+ assert check_argument_types()
20
+ print("init with", init)
21
+
22
+ # weight init
23
+ for p in model.parameters():
24
+ if p.dim() > 1:
25
+ if init == "xavier_uniform":
26
+ torch.nn.init.xavier_uniform_(p.data)
27
+ elif init == "xavier_normal":
28
+ torch.nn.init.xavier_normal_(p.data)
29
+ elif init == "kaiming_uniform":
30
+ torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
31
+ elif init == "kaiming_normal":
32
+ torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
33
+ else:
34
+ raise ValueError("Unknown initialization: " + init)
35
+ # bias init
36
+ for name, p in model.named_parameters():
37
+ if ".bias" in name and p.dim() == 1:
38
+ p.data.zero_()
AR/utils/io.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import torch
4
+ import yaml
5
+
6
+
7
+ def load_yaml_config(path):
8
+ with open(path) as f:
9
+ config = yaml.full_load(f)
10
+ return config
11
+
12
+
13
+ def save_config_to_yaml(config, path):
14
+ assert path.endswith(".yaml")
15
+ with open(path, "w") as f:
16
+ f.write(yaml.dump(config))
17
+ f.close()
18
+
19
+
20
+ def write_args(args, path):
21
+ args_dict = dict(
22
+ (name, getattr(args, name)) for name in dir(args) if not name.startswith("_")
23
+ )
24
+ with open(path, "a") as args_file:
25
+ args_file.write("==> torch version: {}\n".format(torch.__version__))
26
+ args_file.write(
27
+ "==> cudnn version: {}\n".format(torch.backends.cudnn.version())
28
+ )
29
+ args_file.write("==> Cmd:\n")
30
+ args_file.write(str(sys.argv))
31
+ args_file.write("\n==> args:\n")
32
+ for k, v in sorted(args_dict.items()):
33
+ args_file.write(" %s: %s\n" % (str(k), str(v)))
34
+ args_file.close()
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: Gpt Sovits Demo
3
- emoji: 🚀
4
  colorFrom: blue
5
  colorTo: gray
6
  sdk: gradio
@@ -10,4 +10,6 @@ pinned: false
10
  license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
1
  ---
2
+ title: GPT-SoVITS Zero-shot TTS Demo
3
+ emoji: 🚀🚀🚀
4
  colorFrom: blue
5
  colorTo: gray
6
  sdk: gradio
 
10
  license: mit
11
  ---
12
 
13
+ Original:
14
+
15
+ https://github.com/RVC-Boss/GPT-SoVITS
app.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from https://github.com/RVC-Boss/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
2
+ import os
3
+
4
+ gpt_path = os.environ.get(
5
+ "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
6
+ )
7
+ sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
8
+ cnhubert_base_path = os.environ.get(
9
+ "cnhubert_base_path", "pretrained_models/chinese-hubert-base"
10
+ )
11
+ bert_path = os.environ.get(
12
+ "bert_path", "pretrained_models/chinese-roberta-wwm-ext-large"
13
+ )
14
+
15
+ if "_CUDA_VISIBLE_DEVICES" in os.environ:
16
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
17
+ is_half = eval(os.environ.get("is_half", "True"))
18
+
19
+
20
+ import gradio as gr
21
+ import librosa
22
+ import numpy as np
23
+ import torch
24
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
25
+
26
+ from feature_extractor import cnhubert
27
+
28
+ cnhubert.cnhubert_base_path = cnhubert_base_path
29
+ from time import time as ttime
30
+
31
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
32
+ from module.mel_processing import spectrogram_torch
33
+ from module.models import SynthesizerTrn
34
+ from my_utils import load_audio
35
+ from text import cleaned_text_to_sequence
36
+ from text.cleaner import clean_text
37
+
38
+ device = "cuda"
39
+ tokenizer = AutoTokenizer.from_pretrained(bert_path)
40
+ bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
41
+ if is_half == True:
42
+ bert_model = bert_model.half().to(device)
43
+ else:
44
+ bert_model = bert_model.to(device)
45
+
46
+
47
+ # bert_model=bert_model.to(device)
48
+ def get_bert_feature(text, word2ph):
49
+ with torch.no_grad():
50
+ inputs = tokenizer(text, return_tensors="pt")
51
+ for i in inputs:
52
+ inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题,精度随bert_model
53
+ res = bert_model(**inputs, output_hidden_states=True)
54
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
55
+ assert len(word2ph) == len(text)
56
+ phone_level_feature = []
57
+ for i in range(len(word2ph)):
58
+ repeat_feature = res[i].repeat(word2ph[i], 1)
59
+ phone_level_feature.append(repeat_feature)
60
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
61
+ # if(is_half==True):phone_level_feature=phone_level_feature.half()
62
+ return phone_level_feature.T
63
+
64
+
65
+ n_semantic = 1024
66
+ dict_s2 = torch.load(sovits_path, map_location="cpu")
67
+ hps = dict_s2["config"]
68
+
69
+
70
+ class DictToAttrRecursive:
71
+ def __init__(self, input_dict):
72
+ for key, value in input_dict.items():
73
+ if isinstance(value, dict):
74
+ # 如果值是字典,递归调用构造函数
75
+ setattr(self, key, DictToAttrRecursive(value))
76
+ else:
77
+ setattr(self, key, value)
78
+
79
+
80
+ hps = DictToAttrRecursive(hps)
81
+ hps.model.semantic_frame_rate = "25hz"
82
+ dict_s1 = torch.load(gpt_path, map_location="cpu")
83
+ config = dict_s1["config"]
84
+ ssl_model = cnhubert.get_model()
85
+ if is_half == True:
86
+ ssl_model = ssl_model.half().to(device)
87
+ else:
88
+ ssl_model = ssl_model.to(device)
89
+
90
+ vq_model = SynthesizerTrn(
91
+ hps.data.filter_length // 2 + 1,
92
+ hps.train.segment_size // hps.data.hop_length,
93
+ n_speakers=hps.data.n_speakers,
94
+ **hps.model
95
+ )
96
+ if is_half == True:
97
+ vq_model = vq_model.half().to(device)
98
+ else:
99
+ vq_model = vq_model.to(device)
100
+ vq_model.eval()
101
+ print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
102
+ hz = 50
103
+ max_sec = config["data"]["max_sec"]
104
+ # t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
105
+ t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
106
+ t2s_model.load_state_dict(dict_s1["weight"])
107
+ if is_half == True:
108
+ t2s_model = t2s_model.half()
109
+ t2s_model = t2s_model.to(device)
110
+ t2s_model.eval()
111
+ total = sum([param.nelement() for param in t2s_model.parameters()])
112
+ print("Number of parameter: %.2fM" % (total / 1e6))
113
+
114
+
115
+ def get_spepc(hps, filename):
116
+ audio = load_audio(filename, int(hps.data.sampling_rate))
117
+ audio = torch.FloatTensor(audio)
118
+ audio_norm = audio
119
+ audio_norm = audio_norm.unsqueeze(0)
120
+ spec = spectrogram_torch(
121
+ audio_norm,
122
+ hps.data.filter_length,
123
+ hps.data.sampling_rate,
124
+ hps.data.hop_length,
125
+ hps.data.win_length,
126
+ center=False,
127
+ )
128
+ return spec
129
+
130
+
131
+ dict_language = {"Chinese": "zh", "English": "en", "Japanese": "ja"}
132
+
133
+
134
+ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
135
+ if len(prompt_text) > 100 or len(text) > 100:
136
+ return
137
+ t0 = ttime()
138
+ prompt_text = prompt_text.strip("\n")
139
+ prompt_language, text = prompt_language, text.strip("\n")
140
+ with torch.no_grad():
141
+ wav16k, _ = librosa.load(ref_wav_path, sr=16000) # 派蒙
142
+ # length of wav16k in sec should be in 60s
143
+ if len(wav16k) < 16000 * 60:
144
+ return
145
+ wav16k = wav16k[: int(hps.data.sampling_rate * max_sec)]
146
+ wav16k = torch.from_numpy(wav16k)
147
+ if is_half == True:
148
+ wav16k = wav16k.half().to(device)
149
+ else:
150
+ wav16k = wav16k.to(device)
151
+ ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
152
+ "last_hidden_state"
153
+ ].transpose(
154
+ 1, 2
155
+ ) # .float()
156
+ codes = vq_model.extract_latent(ssl_content)
157
+ prompt_semantic = codes[0, 0]
158
+ t1 = ttime()
159
+ prompt_language = dict_language[prompt_language]
160
+ text_language = dict_language[text_language]
161
+ phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
162
+ phones1 = cleaned_text_to_sequence(phones1)
163
+ texts = text.split("\n")
164
+ audio_opt = []
165
+ zero_wav = np.zeros(
166
+ int(hps.data.sampling_rate * 0.3),
167
+ dtype=np.float16 if is_half == True else np.float32,
168
+ )
169
+ for text in texts:
170
+ phones2, word2ph2, norm_text2 = clean_text(text, text_language)
171
+ phones2 = cleaned_text_to_sequence(phones2)
172
+ if prompt_language == "zh":
173
+ bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
174
+ else:
175
+ bert1 = torch.zeros(
176
+ (1024, len(phones1)),
177
+ dtype=torch.float16 if is_half == True else torch.float32,
178
+ ).to(device)
179
+ if text_language == "zh":
180
+ bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
181
+ else:
182
+ bert2 = torch.zeros((1024, len(phones2))).to(bert1)
183
+ bert = torch.cat([bert1, bert2], 1)
184
+
185
+ all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
186
+ bert = bert.to(device).unsqueeze(0)
187
+ all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
188
+ prompt = prompt_semantic.unsqueeze(0).to(device)
189
+ t2 = ttime()
190
+ with torch.no_grad():
191
+ # pred_semantic = t2s_model.model.infer(
192
+ pred_semantic, idx = t2s_model.model.infer_panel(
193
+ all_phoneme_ids,
194
+ all_phoneme_len,
195
+ prompt,
196
+ bert,
197
+ # prompt_phone_len=ph_offset,
198
+ top_k=config["inference"]["top_k"],
199
+ early_stop_num=hz * max_sec,
200
+ )
201
+ t3 = ttime()
202
+ # print(pred_semantic.shape,idx)
203
+ pred_semantic = pred_semantic[:, -idx:].unsqueeze(
204
+ 0
205
+ ) # .unsqueeze(0)#mq要多unsqueeze一次
206
+ refer = get_spepc(hps, ref_wav_path) # .to(device)
207
+ if is_half == True:
208
+ refer = refer.half().to(device)
209
+ else:
210
+ refer = refer.to(device)
211
+ # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
212
+ audio = (
213
+ vq_model.decode(
214
+ pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
215
+ )
216
+ .detach()
217
+ .cpu()
218
+ .numpy()[0, 0]
219
+ ) ###试试重建不带上prompt部分
220
+ audio_opt.append(audio)
221
+ audio_opt.append(zero_wav)
222
+ t4 = ttime()
223
+ print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
224
+ yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
225
+ np.int16
226
+ )
227
+
228
+
229
+ initial_md = """
230
+ # GPT-SoVITS Zero-shot TTS Demo
231
+
232
+ https://github.com/RVC-Boss/GPT-SoVITS
233
+
234
+ *I'm not the author of this model, and I just borrowed it to make a demo.*
235
+
236
+ - *Input text is limited to 100 characters.*
237
+ - *Input audio is limited to 60 seconds.*
238
+
239
+ **License**
240
+
241
+ https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE
242
+
243
+ This software is open source under the MIT License, the author does not have any control over the software, and the user is solely responsible for the use of the software and for the distribution of the sound derived from the software.
244
+ If you do not agree with these terms and conditions, you may not use or reference any of the code or files in the package.
245
+ """
246
+
247
+ with gr.Blocks(title="GPT-SoVITS Zero-shot TTS Demo") as app:
248
+ gr.Markdown(initial_md)
249
+ with gr.Group():
250
+ gr.Markdown(value="*Upload reference audio")
251
+ with gr.Row():
252
+ inp_ref = gr.Audio(label="Reference audio", type="filepath")
253
+ prompt_text = gr.Textbox(label="Transcription of reference audio")
254
+ prompt_language = gr.Dropdown(
255
+ label="Language of reference audio",
256
+ choices=["Chinese", "English", "Japanese"],
257
+ value="Japanese",
258
+ )
259
+ gr.Markdown(value="*Text to synthesize")
260
+ with gr.Row():
261
+ text = gr.Textbox(label="Text to synthesize")
262
+ text_language = gr.Dropdown(
263
+ label="Language of text",
264
+ choices=["Chinese", "English", "Japanese"],
265
+ value="Japanese",
266
+ )
267
+ inference_button = gr.Button("Synthesize", variant="primary")
268
+ output = gr.Audio(label="Result")
269
+ inference_button.click(
270
+ get_tts_wav,
271
+ [inp_ref, prompt_text, prompt_language, text, text_language],
272
+ [output],
273
+ )
274
+
275
+ app.launch(inbrowser=True)
configs/s1.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train:
2
+ seed: 1234
3
+ epochs: 300
4
+ batch_size: 8
5
+ gradient_accumulation: 4
6
+ save_every_n_epoch: 1
7
+ precision: 16
8
+ gradient_clip: 1.0
9
+ optimizer:
10
+ lr: 0.01
11
+ lr_init: 0.00001
12
+ lr_end: 0.0001
13
+ warmup_steps: 2000
14
+ decay_steps: 40000
15
+ data:
16
+ max_eval_sample: 8
17
+ max_sec: 54
18
+ num_workers: 1
19
+ pad_val: 1024 # same with EOS in model
20
+ model:
21
+ vocab_size: 1025
22
+ phoneme_vocab_size: 512
23
+ embedding_dim: 512
24
+ hidden_dim: 512
25
+ head: 16
26
+ linear_units: 2048
27
+ n_layer: 12
28
+ dropout: 0
29
+ EOS: 1024
30
+ inference:
31
+ top_k: 5
configs/s1big.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train:
2
+ seed: 1234
3
+ epochs: 300
4
+ batch_size: 8
5
+ gradient_accumulation: 4
6
+ save_every_n_epoch: 1
7
+ precision: 16-mixed
8
+ gradient_clip: 1.0
9
+ optimizer:
10
+ lr: 0.01
11
+ lr_init: 0.00001
12
+ lr_end: 0.0001
13
+ warmup_steps: 2000
14
+ decay_steps: 40000
15
+ data:
16
+ max_eval_sample: 8
17
+ max_sec: 54
18
+ num_workers: 1
19
+ pad_val: 1024 # same with EOS in model
20
+ model:
21
+ vocab_size: 1025
22
+ phoneme_vocab_size: 512
23
+ embedding_dim: 1024
24
+ hidden_dim: 1024
25
+ head: 16
26
+ linear_units: 2048
27
+ n_layer: 16
28
+ dropout: 0
29
+ EOS: 1024
30
+ inference:
31
+ top_k: 5
configs/s1big2.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train:
2
+ seed: 1234
3
+ epochs: 300
4
+ batch_size: 12
5
+ gradient_accumulation: 4
6
+ save_every_n_epoch: 1
7
+ precision: 16-mixed
8
+ gradient_clip: 1.0
9
+ optimizer:
10
+ lr: 0.01
11
+ lr_init: 0.00001
12
+ lr_end: 0.0001
13
+ warmup_steps: 2000
14
+ decay_steps: 40000
15
+ data:
16
+ max_eval_sample: 8
17
+ max_sec: 54
18
+ num_workers: 1
19
+ pad_val: 1024 # same with EOS in model
20
+ model:
21
+ vocab_size: 1025
22
+ phoneme_vocab_size: 512
23
+ embedding_dim: 1024
24
+ hidden_dim: 1024
25
+ head: 16
26
+ linear_units: 2048
27
+ n_layer: 6
28
+ dropout: 0
29
+ EOS: 1024
30
+ inference:
31
+ top_k: 5
configs/s1longer.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train:
2
+ seed: 1234
3
+ epochs: 20
4
+ batch_size: 8
5
+ save_every_n_epoch: 1
6
+ precision: 16-mixed
7
+ gradient_clip: 1.0
8
+ optimizer:
9
+ lr: 0.01
10
+ lr_init: 0.00001
11
+ lr_end: 0.0001
12
+ warmup_steps: 2000
13
+ decay_steps: 40000
14
+ data:
15
+ max_eval_sample: 8
16
+ max_sec: 54
17
+ num_workers: 4
18
+ pad_val: 1024 # same with EOS in model
19
+ model:
20
+ vocab_size: 1025
21
+ phoneme_vocab_size: 512
22
+ embedding_dim: 512
23
+ hidden_dim: 512
24
+ head: 16
25
+ linear_units: 2048
26
+ n_layer: 24
27
+ dropout: 0
28
+ EOS: 1024
29
+ random_bert: 0
30
+ inference:
31
+ top_k: 5
configs/s1mq.yaml ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train:
2
+ seed: 1234
3
+ epochs: 100
4
+ batch_size: 6
5
+ gradient_accumulation: 4
6
+ save_every_n_epoch: 1
7
+ precision: 32
8
+ gradient_clip: 1.0
9
+ optimizer:
10
+ lr: 0.01
11
+ lr_init: 0.00001
12
+ lr_end: 0.0001
13
+ warmup_steps: 2000
14
+ decay_steps: 40000
15
+ data:
16
+ max_eval_sample: 8
17
+ max_sec: 40
18
+ num_workers: 1
19
+ pad_val: 1024 # same with EOS in model
20
+ model:
21
+ saving_path: "ckpt/"
22
+ resume_checkpoint: null
23
+ vocoder_config_path: "quantizer/new_ckpt/config.json"
24
+ vocoder_ckpt_path: "quantizer/new_ckpt/g_00600000"
25
+ datadir: "/home/liweiche/GigaSpeech/wavs"
26
+ metapath: "/home/liweiche/GigaSpeech/train2.json"
27
+ val_metapath: "/home/liweiche/GigaSpeech/dev2.json"
28
+ sampledir: "logs/"
29
+ pretrained_path: null
30
+ lr: 0.0001
31
+ batch_size: 200.0
32
+ train_bucket_size: 8192
33
+ training_step: 800000
34
+ optim_flat_percent: 0.0
35
+ warmup_step: 50
36
+ adam_beta1: 0.9
37
+ adam_beta2: 0.98
38
+ ffd_size: 3072
39
+ hidden_size: 768
40
+ enc_nlayers: 6
41
+ dec_nlayers: 6
42
+ nheads: 12
43
+ ar_layer: 4
44
+ ar_ffd_size: 1024
45
+ ar_hidden_size: 256
46
+ ar_nheads: 4
47
+ aligner_softmax_temp: 1.0
48
+ layer_norm_eps: 0.00001
49
+ speaker_embed_dropout: 0.05
50
+ label_smoothing: 0.0
51
+ val_check_interval: 5000
52
+ check_val_every_n_epoch: 1
53
+ precision: "fp16"
54
+ nworkers: 16
55
+ distributed: true
56
+ accelerator: "ddp"
57
+ version: null
58
+ accumulate_grad_batches: 1
59
+ use_repetition_token: true
60
+ use_repetition_gating: false
61
+ repetition_penalty: 1.0
62
+ sampling_temperature: 1.0
63
+ top_k: -1
64
+ min_top_k: 3
65
+ top_p: 0.8
66
+ sample_num: 4
67
+ length_penalty_max_length: 15000
68
+ length_penalty_max_prob: 0.95
69
+ max_input_length: 2048
70
+ max_output_length: 2000
71
+ sample_rate: 16000
72
+ n_codes: 1024
73
+ n_cluster_groups: 1
74
+ phone_context_window: 4
75
+ phoneset_size: 1000
76
+ inference:
77
+ top_k: 5
configs/s2.json ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "log_interval": 100,
4
+ "eval_interval": 500,
5
+ "seed": 1234,
6
+ "epochs": 100,
7
+ "learning_rate": 0.0001,
8
+ "betas": [
9
+ 0.8,
10
+ 0.99
11
+ ],
12
+ "eps": 1e-09,
13
+ "batch_size": 32,
14
+ "fp16_run": true,
15
+ "lr_decay": 0.999875,
16
+ "segment_size": 20480,
17
+ "init_lr_ratio": 1,
18
+ "warmup_epochs": 0,
19
+ "c_mel": 45,
20
+ "c_kl": 1.0,
21
+ "text_low_lr_rate": 0.4
22
+ },
23
+ "data": {
24
+ "max_wav_value": 32768.0,
25
+ "sampling_rate": 32000,
26
+ "filter_length": 2048,
27
+ "hop_length": 640,
28
+ "win_length": 2048,
29
+ "n_mel_channels": 128,
30
+ "mel_fmin": 0.0,
31
+ "mel_fmax": null,
32
+ "add_blank": true,
33
+ "n_speakers": 300,
34
+ "cleaned_text": true
35
+ },
36
+ "model": {
37
+ "inter_channels": 192,
38
+ "hidden_channels": 192,
39
+ "filter_channels": 768,
40
+ "n_heads": 2,
41
+ "n_layers": 6,
42
+ "kernel_size": 3,
43
+ "p_dropout": 0.1,
44
+ "resblock": "1",
45
+ "resblock_kernel_sizes": [
46
+ 3,
47
+ 7,
48
+ 11
49
+ ],
50
+ "resblock_dilation_sizes": [
51
+ [
52
+ 1,
53
+ 3,
54
+ 5
55
+ ],
56
+ [
57
+ 1,
58
+ 3,
59
+ 5
60
+ ],
61
+ [
62
+ 1,
63
+ 3,
64
+ 5
65
+ ]
66
+ ],
67
+ "upsample_rates": [
68
+ 10,
69
+ 8,
70
+ 2,
71
+ 2,
72
+ 2
73
+ ],
74
+ "upsample_initial_channel": 512,
75
+ "upsample_kernel_sizes": [
76
+ 16,
77
+ 16,
78
+ 8,
79
+ 2,
80
+ 2
81
+ ],
82
+ "n_layers_q": 3,
83
+ "use_spectral_norm": false,
84
+ "gin_channels": 512,
85
+ "semantic_frame_rate": "25hz",
86
+ "freeze_quantizer": true
87
+ },
88
+ "s2_ckpt_dir": "logs/s2/big2k1",
89
+ "content_module": "cnhubert"
90
+ }
configs/train.yaml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gpu:
2
+ n_card: 1
3
+ n_process_per_card: 2
4
+ io:
5
+ text_path: D:\RVC1006\GPT-SoVITS\GPT_SoVITS
6
+ save_every_n_epoch: 1
7
+ precision: 16-mixed
8
+ gradient_clip: 1.0
9
+ optimizer:
10
+ lr: 0.01
11
+ lr_init: 0.00001
12
+ lr_end: 0.0001
13
+ warmup_steps: 2000
14
+ decay_steps: 40000
15
+ data:
16
+ max_eval_sample: 8
17
+ max_sec: 54
18
+ num_workers: 1
19
+ pad_val: 1024 # same with EOS in model
20
+ model:
21
+ vocab_size: 1025
22
+ phoneme_vocab_size: 512
23
+ embedding_dim: 512
24
+ hidden_dim: 512
25
+ head: 16
26
+ linear_units: 2048
27
+ n_layer: 24
28
+ dropout: 0
29
+ EOS: 1024
30
+ random_bert: 0
31
+ inference:
32
+ top_k: 5
feature_extractor/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from . import cnhubert, whisper_enc
2
+
3
+ content_module_map = {
4
+ 'cnhubert': cnhubert,
5
+ 'whisper': whisper_enc
6
+ }
feature_extractor/cnhubert.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import librosa
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import soundfile as sf
7
+ import logging
8
+
9
+ logging.getLogger("numba").setLevel(logging.WARNING)
10
+
11
+ from transformers import (
12
+ Wav2Vec2FeatureExtractor,
13
+ HubertModel,
14
+ )
15
+
16
+ import utils
17
+ import torch.nn as nn
18
+
19
+ cnhubert_base_path = None
20
+
21
+
22
+ class CNHubert(nn.Module):
23
+ def __init__(self):
24
+ super().__init__()
25
+ self.model = HubertModel.from_pretrained(cnhubert_base_path)
26
+ self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
27
+ cnhubert_base_path
28
+ )
29
+
30
+ def forward(self, x):
31
+ input_values = self.feature_extractor(
32
+ x, return_tensors="pt", sampling_rate=16000
33
+ ).input_values.to(x.device)
34
+ feats = self.model(input_values)["last_hidden_state"]
35
+ return feats
36
+
37
+
38
+ # class CNHubertLarge(nn.Module):
39
+ # def __init__(self):
40
+ # super().__init__()
41
+ # self.model = HubertModel.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-hubert-large")
42
+ # self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-hubert-large")
43
+ # def forward(self, x):
44
+ # input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
45
+ # feats = self.model(input_values)["last_hidden_state"]
46
+ # return feats
47
+ #
48
+ # class CVec(nn.Module):
49
+ # def __init__(self):
50
+ # super().__init__()
51
+ # self.model = HubertModel.from_pretrained("/data/docker/liujing04/vc-webui-big/hubert_base")
52
+ # self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/vc-webui-big/hubert_base")
53
+ # def forward(self, x):
54
+ # input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
55
+ # feats = self.model(input_values)["last_hidden_state"]
56
+ # return feats
57
+ #
58
+ # class cnw2v2base(nn.Module):
59
+ # def __init__(self):
60
+ # super().__init__()
61
+ # self.model = Wav2Vec2Model.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-wav2vec2-base")
62
+ # self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-wav2vec2-base")
63
+ # def forward(self, x):
64
+ # input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
65
+ # feats = self.model(input_values)["last_hidden_state"]
66
+ # return feats
67
+
68
+
69
+ def get_model():
70
+ model = CNHubert()
71
+ model.eval()
72
+ return model
73
+
74
+
75
+ # def get_large_model():
76
+ # model = CNHubertLarge()
77
+ # model.eval()
78
+ # return model
79
+ #
80
+ # def get_model_cvec():
81
+ # model = CVec()
82
+ # model.eval()
83
+ # return model
84
+ #
85
+ # def get_model_cnw2v2base():
86
+ # model = cnw2v2base()
87
+ # model.eval()
88
+ # return model
89
+
90
+
91
+ def get_content(hmodel, wav_16k_tensor):
92
+ with torch.no_grad():
93
+ feats = hmodel(wav_16k_tensor)
94
+ return feats.transpose(1, 2)
95
+
96
+
97
+ if __name__ == "__main__":
98
+ model = get_model()
99
+ src_path = "/Users/Shared/原音频2.wav"
100
+ wav_16k_tensor = utils.load_wav_to_torch_and_resample(src_path, 16000)
101
+ model = model
102
+ wav_16k_tensor = wav_16k_tensor
103
+ feats = get_content(model, wav_16k_tensor)
104
+ print(feats.shape)
feature_extractor/whisper_enc.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def get_model():
5
+ import whisper
6
+
7
+ model = whisper.load_model("small", device="cpu")
8
+
9
+ return model.encoder
10
+
11
+
12
+ def get_content(model=None, wav_16k_tensor=None):
13
+ from whisper import log_mel_spectrogram, pad_or_trim
14
+
15
+ dev = next(model.parameters()).device
16
+ mel = log_mel_spectrogram(wav_16k_tensor).to(dev)[:, :3000]
17
+ # if torch.cuda.is_available():
18
+ # mel = mel.to(torch.float16)
19
+ feature_len = mel.shape[-1] // 2
20
+ assert mel.shape[-1] < 3000, "输入音频过长,只允许输入30以内音频"
21
+ with torch.no_grad():
22
+ feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[
23
+ :1, :feature_len, :
24
+ ].transpose(1, 2)
25
+ return feature
inference_webui.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ gpt_path = os.environ.get(
4
+ "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
5
+ )
6
+ sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
7
+ cnhubert_base_path = os.environ.get(
8
+ "cnhubert_base_path", "pretrained_models/chinese-hubert-base"
9
+ )
10
+ bert_path = os.environ.get(
11
+ "bert_path", "pretrained_models/chinese-roberta-wwm-ext-large"
12
+ )
13
+ infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
14
+ infer_ttswebui = int(infer_ttswebui)
15
+ if "_CUDA_VISIBLE_DEVICES" in os.environ:
16
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
17
+ is_half = eval(os.environ.get("is_half", "True"))
18
+ import gradio as gr
19
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
20
+ import numpy as np
21
+ import librosa,torch
22
+ from feature_extractor import cnhubert
23
+ cnhubert.cnhubert_base_path=cnhubert_base_path
24
+
25
+ from module.models import SynthesizerTrn
26
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
27
+ from text import cleaned_text_to_sequence
28
+ from text.cleaner import clean_text
29
+ from time import time as ttime
30
+ from module.mel_processing import spectrogram_torch
31
+ from my_utils import load_audio
32
+
33
+ device = "cuda"
34
+ tokenizer = AutoTokenizer.from_pretrained(bert_path)
35
+ bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
36
+ if is_half == True:
37
+ bert_model = bert_model.half().to(device)
38
+ else:
39
+ bert_model = bert_model.to(device)
40
+
41
+
42
+ # bert_model=bert_model.to(device)
43
+ def get_bert_feature(text, word2ph):
44
+ with torch.no_grad():
45
+ inputs = tokenizer(text, return_tensors="pt")
46
+ for i in inputs:
47
+ inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题,精度随bert_model
48
+ res = bert_model(**inputs, output_hidden_states=True)
49
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
50
+ assert len(word2ph) == len(text)
51
+ phone_level_feature = []
52
+ for i in range(len(word2ph)):
53
+ repeat_feature = res[i].repeat(word2ph[i], 1)
54
+ phone_level_feature.append(repeat_feature)
55
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
56
+ # if(is_half==True):phone_level_feature=phone_level_feature.half()
57
+ return phone_level_feature.T
58
+
59
+
60
+ n_semantic = 1024
61
+
62
+ dict_s2=torch.load(sovits_path,map_location="cpu")
63
+ hps=dict_s2["config"]
64
+
65
+ class DictToAttrRecursive(dict):
66
+ def __init__(self, input_dict):
67
+ super().__init__(input_dict)
68
+ for key, value in input_dict.items():
69
+ if isinstance(value, dict):
70
+ value = DictToAttrRecursive(value)
71
+ self[key] = value
72
+ setattr(self, key, value)
73
+
74
+ def __getattr__(self, item):
75
+ try:
76
+ return self[item]
77
+ except KeyError:
78
+ raise AttributeError(f"Attribute {item} not found")
79
+
80
+ def __setattr__(self, key, value):
81
+ if isinstance(value, dict):
82
+ value = DictToAttrRecursive(value)
83
+ super(DictToAttrRecursive, self).__setitem__(key, value)
84
+ super().__setattr__(key, value)
85
+
86
+ def __delattr__(self, item):
87
+ try:
88
+ del self[item]
89
+ except KeyError:
90
+ raise AttributeError(f"Attribute {item} not found")
91
+
92
+
93
+ hps = DictToAttrRecursive(hps)
94
+
95
+ hps.model.semantic_frame_rate = "25hz"
96
+ dict_s1 = torch.load(gpt_path, map_location="cpu")
97
+ config = dict_s1["config"]
98
+ ssl_model = cnhubert.get_model()
99
+ if is_half == True:
100
+ ssl_model = ssl_model.half().to(device)
101
+ else:
102
+ ssl_model = ssl_model.to(device)
103
+
104
+ vq_model = SynthesizerTrn(
105
+ hps.data.filter_length // 2 + 1,
106
+ hps.train.segment_size // hps.data.hop_length,
107
+ n_speakers=hps.data.n_speakers,
108
+ **hps.model
109
+ )
110
+ if is_half == True:
111
+ vq_model = vq_model.half().to(device)
112
+ else:
113
+ vq_model = vq_model.to(device)
114
+ vq_model.eval()
115
+ print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
116
+ hz = 50
117
+ max_sec = config["data"]["max_sec"]
118
+ # t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
119
+ t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
120
+ t2s_model.load_state_dict(dict_s1["weight"])
121
+ if is_half == True:
122
+ t2s_model = t2s_model.half()
123
+ t2s_model = t2s_model.to(device)
124
+ t2s_model.eval()
125
+ total = sum([param.nelement() for param in t2s_model.parameters()])
126
+ print("Number of parameter: %.2fM" % (total / 1e6))
127
+
128
+
129
+ def get_spepc(hps, filename):
130
+ audio = load_audio(filename, int(hps.data.sampling_rate))
131
+ audio = torch.FloatTensor(audio)
132
+ audio_norm = audio
133
+ audio_norm = audio_norm.unsqueeze(0)
134
+ spec = spectrogram_torch(
135
+ audio_norm,
136
+ hps.data.filter_length,
137
+ hps.data.sampling_rate,
138
+ hps.data.hop_length,
139
+ hps.data.win_length,
140
+ center=False,
141
+ )
142
+ return spec
143
+
144
+
145
+ dict_language = {"中文": "zh", "英文": "en", "日文": "ja"}
146
+
147
+
148
+ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
149
+ t0 = ttime()
150
+ prompt_text = prompt_text.strip("\n")
151
+ prompt_language, text = prompt_language, text.strip("\n")
152
+ with torch.no_grad():
153
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000) # 派蒙
154
+ wav16k = torch.from_numpy(wav16k)
155
+ if is_half == True:
156
+ wav16k = wav16k.half().to(device)
157
+ else:
158
+ wav16k = wav16k.to(device)
159
+ ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
160
+ "last_hidden_state"
161
+ ].transpose(
162
+ 1, 2
163
+ ) # .float()
164
+ codes = vq_model.extract_latent(ssl_content)
165
+ prompt_semantic = codes[0, 0]
166
+ t1 = ttime()
167
+ prompt_language = dict_language[prompt_language]
168
+ text_language = dict_language[text_language]
169
+ phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
170
+ phones1 = cleaned_text_to_sequence(phones1)
171
+ texts = text.split("\n")
172
+ audio_opt = []
173
+ zero_wav = np.zeros(
174
+ int(hps.data.sampling_rate * 0.3),
175
+ dtype=np.float16 if is_half == True else np.float32,
176
+ )
177
+ for text in texts:
178
+ phones2, word2ph2, norm_text2 = clean_text(text, text_language)
179
+ phones2 = cleaned_text_to_sequence(phones2)
180
+ if prompt_language == "zh":
181
+ bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
182
+ else:
183
+ bert1 = torch.zeros(
184
+ (1024, len(phones1)),
185
+ dtype=torch.float16 if is_half == True else torch.float32,
186
+ ).to(device)
187
+ if text_language == "zh":
188
+ bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
189
+ else:
190
+ bert2 = torch.zeros((1024, len(phones2))).to(bert1)
191
+ bert = torch.cat([bert1, bert2], 1)
192
+
193
+ all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
194
+ bert = bert.to(device).unsqueeze(0)
195
+ all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
196
+ prompt = prompt_semantic.unsqueeze(0).to(device)
197
+ t2 = ttime()
198
+ with torch.no_grad():
199
+ # pred_semantic = t2s_model.model.infer(
200
+ pred_semantic, idx = t2s_model.model.infer_panel(
201
+ all_phoneme_ids,
202
+ all_phoneme_len,
203
+ prompt,
204
+ bert,
205
+ # prompt_phone_len=ph_offset,
206
+ top_k=config["inference"]["top_k"],
207
+ early_stop_num=hz * max_sec,
208
+ )
209
+ t3 = ttime()
210
+ # print(pred_semantic.shape,idx)
211
+ pred_semantic = pred_semantic[:, -idx:].unsqueeze(
212
+ 0
213
+ ) # .unsqueeze(0)#mq要多unsqueeze一次
214
+ refer = get_spepc(hps, ref_wav_path) # .to(device)
215
+ if is_half == True:
216
+ refer = refer.half().to(device)
217
+ else:
218
+ refer = refer.to(device)
219
+ # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
220
+ audio = (
221
+ vq_model.decode(
222
+ pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
223
+ )
224
+ .detach()
225
+ .cpu()
226
+ .numpy()[0, 0]
227
+ ) ###试试重建不带上prompt部分
228
+ audio_opt.append(audio)
229
+ audio_opt.append(zero_wav)
230
+ t4 = ttime()
231
+ print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
232
+ yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
233
+ np.int16
234
+ )
235
+
236
+
237
+ splits = {
238
+ ",",
239
+ "。",
240
+ "?",
241
+ "!",
242
+ ",",
243
+ ".",
244
+ "?",
245
+ "!",
246
+ "~",
247
+ ":",
248
+ ":",
249
+ "—",
250
+ "…",
251
+ } # 不考虑省略号
252
+
253
+
254
+ def split(todo_text):
255
+ todo_text = todo_text.replace("……", "。").replace("——", ",")
256
+ if todo_text[-1] not in splits:
257
+ todo_text += "。"
258
+ i_split_head = i_split_tail = 0
259
+ len_text = len(todo_text)
260
+ todo_texts = []
261
+ while 1:
262
+ if i_split_head >= len_text:
263
+ break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
264
+ if todo_text[i_split_head] in splits:
265
+ i_split_head += 1
266
+ todo_texts.append(todo_text[i_split_tail:i_split_head])
267
+ i_split_tail = i_split_head
268
+ else:
269
+ i_split_head += 1
270
+ return todo_texts
271
+
272
+
273
+ def cut1(inp):
274
+ inp = inp.strip("\n")
275
+ inps = split(inp)
276
+ split_idx = list(range(0, len(inps), 5))
277
+ split_idx[-1] = None
278
+ if len(split_idx) > 1:
279
+ opts = []
280
+ for idx in range(len(split_idx) - 1):
281
+ opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
282
+ else:
283
+ opts = [inp]
284
+ return "\n".join(opts)
285
+
286
+
287
+ def cut2(inp):
288
+ inp = inp.strip("\n")
289
+ inps = split(inp)
290
+ if len(inps) < 2:
291
+ return [inp]
292
+ opts = []
293
+ summ = 0
294
+ tmp_str = ""
295
+ for i in range(len(inps)):
296
+ summ += len(inps[i])
297
+ tmp_str += inps[i]
298
+ if summ > 50:
299
+ summ = 0
300
+ opts.append(tmp_str)
301
+ tmp_str = ""
302
+ if tmp_str != "":
303
+ opts.append(tmp_str)
304
+ if len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
305
+ opts[-2] = opts[-2] + opts[-1]
306
+ opts = opts[:-1]
307
+ return "\n".join(opts)
308
+
309
+
310
+ def cut3(inp):
311
+ inp = inp.strip("\n")
312
+ return "\n".join(["%s。" % item for item in inp.strip("。").split("。")])
313
+
314
+
315
+ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
316
+ gr.Markdown(
317
+ value="本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>."
318
+ )
319
+ # with gr.Tabs():
320
+ # with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
321
+ with gr.Group():
322
+ gr.Markdown(value="*请上传并填写参考信息")
323
+ with gr.Row():
324
+ inp_ref = gr.Audio(label="请上传参考音频", type="filepath")
325
+ prompt_text = gr.Textbox(label="参考音频的文本", value="")
326
+ prompt_language = gr.Dropdown(
327
+ label="参考音频的语种", choices=["中文", "英文", "日文"], value="中文"
328
+ )
329
+ gr.Markdown(value="*请填写需要合成的目标文本")
330
+ with gr.Row():
331
+ text = gr.Textbox(label="需要合成的文本", value="")
332
+ text_language = gr.Dropdown(
333
+ label="需要合成的语种", choices=["中文", "英文", "日文"], value="中文"
334
+ )
335
+ inference_button = gr.Button("合成语音", variant="primary")
336
+ output = gr.Audio(label="输出的语音")
337
+ inference_button.click(
338
+ get_tts_wav,
339
+ [inp_ref, prompt_text, prompt_language, text, text_language],
340
+ [output],
341
+ )
342
+
343
+ gr.Markdown(value="文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")
344
+ with gr.Row():
345
+ text_inp = gr.Textbox(label="需要合成的切分前文本", value="")
346
+ button1 = gr.Button("凑五句一切", variant="primary")
347
+ button2 = gr.Button("凑50字一切", variant="primary")
348
+ button3 = gr.Button("按中文句号。切", variant="primary")
349
+ text_opt = gr.Textbox(label="切分后文本", value="")
350
+ button1.click(cut1, [text_inp], [text_opt])
351
+ button2.click(cut2, [text_inp], [text_opt])
352
+ button3.click(cut3, [text_inp], [text_opt])
353
+ gr.Markdown(value="后续将支持混合语种编码文本输入。")
354
+
355
+ app.queue(concurrency_count=511, max_size=1022).launch(
356
+ server_name="0.0.0.0",
357
+ inbrowser=True,
358
+ server_port=infer_ttswebui,
359
+ quiet=True,
360
+ )
module/__init__.py ADDED
File without changes
module/attentions.py ADDED
@@ -0,0 +1,709 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from module import commons
7
+ from module.modules import LayerNorm
8
+
9
+
10
+ class Encoder(nn.Module):
11
+ def __init__(
12
+ self,
13
+ hidden_channels,
14
+ filter_channels,
15
+ n_heads,
16
+ n_layers,
17
+ kernel_size=1,
18
+ p_dropout=0.0,
19
+ window_size=4,
20
+ isflow=False,
21
+ **kwargs
22
+ ):
23
+ super().__init__()
24
+ self.hidden_channels = hidden_channels
25
+ self.filter_channels = filter_channels
26
+ self.n_heads = n_heads
27
+ self.n_layers = n_layers
28
+ self.kernel_size = kernel_size
29
+ self.p_dropout = p_dropout
30
+ self.window_size = window_size
31
+
32
+ self.drop = nn.Dropout(p_dropout)
33
+ self.attn_layers = nn.ModuleList()
34
+ self.norm_layers_1 = nn.ModuleList()
35
+ self.ffn_layers = nn.ModuleList()
36
+ self.norm_layers_2 = nn.ModuleList()
37
+ for i in range(self.n_layers):
38
+ self.attn_layers.append(
39
+ MultiHeadAttention(
40
+ hidden_channels,
41
+ hidden_channels,
42
+ n_heads,
43
+ p_dropout=p_dropout,
44
+ window_size=window_size,
45
+ )
46
+ )
47
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
48
+ self.ffn_layers.append(
49
+ FFN(
50
+ hidden_channels,
51
+ hidden_channels,
52
+ filter_channels,
53
+ kernel_size,
54
+ p_dropout=p_dropout,
55
+ )
56
+ )
57
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
58
+ if isflow:
59
+ cond_layer = torch.nn.Conv1d(
60
+ kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
61
+ )
62
+ self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
63
+ self.cond_layer = weight_norm_modules(cond_layer, name="weight")
64
+ self.gin_channels = kwargs["gin_channels"]
65
+
66
+ def forward(self, x, x_mask, g=None):
67
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
68
+ x = x * x_mask
69
+ if g is not None:
70
+ g = self.cond_layer(g)
71
+
72
+ for i in range(self.n_layers):
73
+ if g is not None:
74
+ x = self.cond_pre(x)
75
+ cond_offset = i * 2 * self.hidden_channels
76
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
77
+ x = commons.fused_add_tanh_sigmoid_multiply(
78
+ x, g_l, torch.IntTensor([self.hidden_channels])
79
+ )
80
+ y = self.attn_layers[i](x, x, attn_mask)
81
+ y = self.drop(y)
82
+ x = self.norm_layers_1[i](x + y)
83
+
84
+ y = self.ffn_layers[i](x, x_mask)
85
+ y = self.drop(y)
86
+ x = self.norm_layers_2[i](x + y)
87
+ x = x * x_mask
88
+ return x
89
+
90
+
91
+ class Decoder(nn.Module):
92
+ def __init__(
93
+ self,
94
+ hidden_channels,
95
+ filter_channels,
96
+ n_heads,
97
+ n_layers,
98
+ kernel_size=1,
99
+ p_dropout=0.0,
100
+ proximal_bias=False,
101
+ proximal_init=True,
102
+ **kwargs
103
+ ):
104
+ super().__init__()
105
+ self.hidden_channels = hidden_channels
106
+ self.filter_channels = filter_channels
107
+ self.n_heads = n_heads
108
+ self.n_layers = n_layers
109
+ self.kernel_size = kernel_size
110
+ self.p_dropout = p_dropout
111
+ self.proximal_bias = proximal_bias
112
+ self.proximal_init = proximal_init
113
+
114
+ self.drop = nn.Dropout(p_dropout)
115
+ self.self_attn_layers = nn.ModuleList()
116
+ self.norm_layers_0 = nn.ModuleList()
117
+ self.encdec_attn_layers = nn.ModuleList()
118
+ self.norm_layers_1 = nn.ModuleList()
119
+ self.ffn_layers = nn.ModuleList()
120
+ self.norm_layers_2 = nn.ModuleList()
121
+ for i in range(self.n_layers):
122
+ self.self_attn_layers.append(
123
+ MultiHeadAttention(
124
+ hidden_channels,
125
+ hidden_channels,
126
+ n_heads,
127
+ p_dropout=p_dropout,
128
+ proximal_bias=proximal_bias,
129
+ proximal_init=proximal_init,
130
+ )
131
+ )
132
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
133
+ self.encdec_attn_layers.append(
134
+ MultiHeadAttention(
135
+ hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
136
+ )
137
+ )
138
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
139
+ self.ffn_layers.append(
140
+ FFN(
141
+ hidden_channels,
142
+ hidden_channels,
143
+ filter_channels,
144
+ kernel_size,
145
+ p_dropout=p_dropout,
146
+ causal=True,
147
+ )
148
+ )
149
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
150
+
151
+ def forward(self, x, x_mask, h, h_mask):
152
+ """
153
+ x: decoder input
154
+ h: encoder output
155
+ """
156
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
157
+ device=x.device, dtype=x.dtype
158
+ )
159
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
160
+ x = x * x_mask
161
+ for i in range(self.n_layers):
162
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
163
+ y = self.drop(y)
164
+ x = self.norm_layers_0[i](x + y)
165
+
166
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
167
+ y = self.drop(y)
168
+ x = self.norm_layers_1[i](x + y)
169
+
170
+ y = self.ffn_layers[i](x, x_mask)
171
+ y = self.drop(y)
172
+ x = self.norm_layers_2[i](x + y)
173
+ x = x * x_mask
174
+ return x
175
+
176
+
177
+ class MultiHeadAttention(nn.Module):
178
+ def __init__(
179
+ self,
180
+ channels,
181
+ out_channels,
182
+ n_heads,
183
+ p_dropout=0.0,
184
+ window_size=None,
185
+ heads_share=True,
186
+ block_length=None,
187
+ proximal_bias=False,
188
+ proximal_init=False,
189
+ ):
190
+ super().__init__()
191
+ assert channels % n_heads == 0
192
+
193
+ self.channels = channels
194
+ self.out_channels = out_channels
195
+ self.n_heads = n_heads
196
+ self.p_dropout = p_dropout
197
+ self.window_size = window_size
198
+ self.heads_share = heads_share
199
+ self.block_length = block_length
200
+ self.proximal_bias = proximal_bias
201
+ self.proximal_init = proximal_init
202
+ self.attn = None
203
+
204
+ self.k_channels = channels // n_heads
205
+ self.conv_q = nn.Conv1d(channels, channels, 1)
206
+ self.conv_k = nn.Conv1d(channels, channels, 1)
207
+ self.conv_v = nn.Conv1d(channels, channels, 1)
208
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
209
+ self.drop = nn.Dropout(p_dropout)
210
+
211
+ if window_size is not None:
212
+ n_heads_rel = 1 if heads_share else n_heads
213
+ rel_stddev = self.k_channels**-0.5
214
+ self.emb_rel_k = nn.Parameter(
215
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
216
+ * rel_stddev
217
+ )
218
+ self.emb_rel_v = nn.Parameter(
219
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
220
+ * rel_stddev
221
+ )
222
+
223
+ nn.init.xavier_uniform_(self.conv_q.weight)
224
+ nn.init.xavier_uniform_(self.conv_k.weight)
225
+ nn.init.xavier_uniform_(self.conv_v.weight)
226
+ if proximal_init:
227
+ with torch.no_grad():
228
+ self.conv_k.weight.copy_(self.conv_q.weight)
229
+ self.conv_k.bias.copy_(self.conv_q.bias)
230
+
231
+ def forward(self, x, c, attn_mask=None):
232
+ q = self.conv_q(x)
233
+ k = self.conv_k(c)
234
+ v = self.conv_v(c)
235
+
236
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
237
+
238
+ x = self.conv_o(x)
239
+ return x
240
+
241
+ def attention(self, query, key, value, mask=None):
242
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
243
+ b, d, t_s, t_t = (*key.size(), query.size(2))
244
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
245
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
246
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
247
+
248
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
249
+ if self.window_size is not None:
250
+ assert (
251
+ t_s == t_t
252
+ ), "Relative attention is only available for self-attention."
253
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
254
+ rel_logits = self._matmul_with_relative_keys(
255
+ query / math.sqrt(self.k_channels), key_relative_embeddings
256
+ )
257
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
258
+ scores = scores + scores_local
259
+ if self.proximal_bias:
260
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
261
+ scores = scores + self._attention_bias_proximal(t_s).to(
262
+ device=scores.device, dtype=scores.dtype
263
+ )
264
+ if mask is not None:
265
+ scores = scores.masked_fill(mask == 0, -1e4)
266
+ if self.block_length is not None:
267
+ assert (
268
+ t_s == t_t
269
+ ), "Local attention is only available for self-attention."
270
+ block_mask = (
271
+ torch.ones_like(scores)
272
+ .triu(-self.block_length)
273
+ .tril(self.block_length)
274
+ )
275
+ scores = scores.masked_fill(block_mask == 0, -1e4)
276
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
277
+ p_attn = self.drop(p_attn)
278
+ output = torch.matmul(p_attn, value)
279
+ if self.window_size is not None:
280
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
281
+ value_relative_embeddings = self._get_relative_embeddings(
282
+ self.emb_rel_v, t_s
283
+ )
284
+ output = output + self._matmul_with_relative_values(
285
+ relative_weights, value_relative_embeddings
286
+ )
287
+ output = (
288
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
289
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
290
+ return output, p_attn
291
+
292
+ def _matmul_with_relative_values(self, x, y):
293
+ """
294
+ x: [b, h, l, m]
295
+ y: [h or 1, m, d]
296
+ ret: [b, h, l, d]
297
+ """
298
+ ret = torch.matmul(x, y.unsqueeze(0))
299
+ return ret
300
+
301
+ def _matmul_with_relative_keys(self, x, y):
302
+ """
303
+ x: [b, h, l, d]
304
+ y: [h or 1, m, d]
305
+ ret: [b, h, l, m]
306
+ """
307
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
308
+ return ret
309
+
310
+ def _get_relative_embeddings(self, relative_embeddings, length):
311
+ max_relative_position = 2 * self.window_size + 1
312
+ # Pad first before slice to avoid using cond ops.
313
+ pad_length = max(length - (self.window_size + 1), 0)
314
+ slice_start_position = max((self.window_size + 1) - length, 0)
315
+ slice_end_position = slice_start_position + 2 * length - 1
316
+ if pad_length > 0:
317
+ padded_relative_embeddings = F.pad(
318
+ relative_embeddings,
319
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
320
+ )
321
+ else:
322
+ padded_relative_embeddings = relative_embeddings
323
+ used_relative_embeddings = padded_relative_embeddings[
324
+ :, slice_start_position:slice_end_position
325
+ ]
326
+ return used_relative_embeddings
327
+
328
+ def _relative_position_to_absolute_position(self, x):
329
+ """
330
+ x: [b, h, l, 2*l-1]
331
+ ret: [b, h, l, l]
332
+ """
333
+ batch, heads, length, _ = x.size()
334
+ # Concat columns of pad to shift from relative to absolute indexing.
335
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
336
+
337
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
338
+ x_flat = x.view([batch, heads, length * 2 * length])
339
+ x_flat = F.pad(
340
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
341
+ )
342
+
343
+ # Reshape and slice out the padded elements.
344
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
345
+ :, :, :length, length - 1 :
346
+ ]
347
+ return x_final
348
+
349
+ def _absolute_position_to_relative_position(self, x):
350
+ """
351
+ x: [b, h, l, l]
352
+ ret: [b, h, l, 2*l-1]
353
+ """
354
+ batch, heads, length, _ = x.size()
355
+ # padd along column
356
+ x = F.pad(
357
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
358
+ )
359
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
360
+ # add 0's in the beginning that will skew the elements after reshape
361
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
362
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
363
+ return x_final
364
+
365
+ def _attention_bias_proximal(self, length):
366
+ """Bias for self-attention to encourage attention to close positions.
367
+ Args:
368
+ length: an integer scalar.
369
+ Returns:
370
+ a Tensor with shape [1, 1, length, length]
371
+ """
372
+ r = torch.arange(length, dtype=torch.float32)
373
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
374
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
375
+
376
+
377
+ class FFN(nn.Module):
378
+ def __init__(
379
+ self,
380
+ in_channels,
381
+ out_channels,
382
+ filter_channels,
383
+ kernel_size,
384
+ p_dropout=0.0,
385
+ activation=None,
386
+ causal=False,
387
+ ):
388
+ super().__init__()
389
+ self.in_channels = in_channels
390
+ self.out_channels = out_channels
391
+ self.filter_channels = filter_channels
392
+ self.kernel_size = kernel_size
393
+ self.p_dropout = p_dropout
394
+ self.activation = activation
395
+ self.causal = causal
396
+
397
+ if causal:
398
+ self.padding = self._causal_padding
399
+ else:
400
+ self.padding = self._same_padding
401
+
402
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
403
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
404
+ self.drop = nn.Dropout(p_dropout)
405
+
406
+ def forward(self, x, x_mask):
407
+ x = self.conv_1(self.padding(x * x_mask))
408
+ if self.activation == "gelu":
409
+ x = x * torch.sigmoid(1.702 * x)
410
+ else:
411
+ x = torch.relu(x)
412
+ x = self.drop(x)
413
+ x = self.conv_2(self.padding(x * x_mask))
414
+ return x * x_mask
415
+
416
+ def _causal_padding(self, x):
417
+ if self.kernel_size == 1:
418
+ return x
419
+ pad_l = self.kernel_size - 1
420
+ pad_r = 0
421
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
422
+ x = F.pad(x, commons.convert_pad_shape(padding))
423
+ return x
424
+
425
+ def _same_padding(self, x):
426
+ if self.kernel_size == 1:
427
+ return x
428
+ pad_l = (self.kernel_size - 1) // 2
429
+ pad_r = self.kernel_size // 2
430
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
431
+ x = F.pad(x, commons.convert_pad_shape(padding))
432
+ return x
433
+
434
+
435
+ import torch.nn as nn
436
+ from torch.nn.utils import remove_weight_norm, weight_norm
437
+
438
+
439
+ class Depthwise_Separable_Conv1D(nn.Module):
440
+ def __init__(
441
+ self,
442
+ in_channels,
443
+ out_channels,
444
+ kernel_size,
445
+ stride=1,
446
+ padding=0,
447
+ dilation=1,
448
+ bias=True,
449
+ padding_mode="zeros", # TODO: refine this type
450
+ device=None,
451
+ dtype=None,
452
+ ):
453
+ super().__init__()
454
+ self.depth_conv = nn.Conv1d(
455
+ in_channels=in_channels,
456
+ out_channels=in_channels,
457
+ kernel_size=kernel_size,
458
+ groups=in_channels,
459
+ stride=stride,
460
+ padding=padding,
461
+ dilation=dilation,
462
+ bias=bias,
463
+ padding_mode=padding_mode,
464
+ device=device,
465
+ dtype=dtype,
466
+ )
467
+ self.point_conv = nn.Conv1d(
468
+ in_channels=in_channels,
469
+ out_channels=out_channels,
470
+ kernel_size=1,
471
+ bias=bias,
472
+ device=device,
473
+ dtype=dtype,
474
+ )
475
+
476
+ def forward(self, input):
477
+ return self.point_conv(self.depth_conv(input))
478
+
479
+ def weight_norm(self):
480
+ self.depth_conv = weight_norm(self.depth_conv, name="weight")
481
+ self.point_conv = weight_norm(self.point_conv, name="weight")
482
+
483
+ def remove_weight_norm(self):
484
+ self.depth_conv = remove_weight_norm(self.depth_conv, name="weight")
485
+ self.point_conv = remove_weight_norm(self.point_conv, name="weight")
486
+
487
+
488
+ class Depthwise_Separable_TransposeConv1D(nn.Module):
489
+ def __init__(
490
+ self,
491
+ in_channels,
492
+ out_channels,
493
+ kernel_size,
494
+ stride=1,
495
+ padding=0,
496
+ output_padding=0,
497
+ bias=True,
498
+ dilation=1,
499
+ padding_mode="zeros", # TODO: refine this type
500
+ device=None,
501
+ dtype=None,
502
+ ):
503
+ super().__init__()
504
+ self.depth_conv = nn.ConvTranspose1d(
505
+ in_channels=in_channels,
506
+ out_channels=in_channels,
507
+ kernel_size=kernel_size,
508
+ groups=in_channels,
509
+ stride=stride,
510
+ output_padding=output_padding,
511
+ padding=padding,
512
+ dilation=dilation,
513
+ bias=bias,
514
+ padding_mode=padding_mode,
515
+ device=device,
516
+ dtype=dtype,
517
+ )
518
+ self.point_conv = nn.Conv1d(
519
+ in_channels=in_channels,
520
+ out_channels=out_channels,
521
+ kernel_size=1,
522
+ bias=bias,
523
+ device=device,
524
+ dtype=dtype,
525
+ )
526
+
527
+ def forward(self, input):
528
+ return self.point_conv(self.depth_conv(input))
529
+
530
+ def weight_norm(self):
531
+ self.depth_conv = weight_norm(self.depth_conv, name="weight")
532
+ self.point_conv = weight_norm(self.point_conv, name="weight")
533
+
534
+ def remove_weight_norm(self):
535
+ remove_weight_norm(self.depth_conv, name="weight")
536
+ remove_weight_norm(self.point_conv, name="weight")
537
+
538
+
539
+ def weight_norm_modules(module, name="weight", dim=0):
540
+ if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
541
+ module, Depthwise_Separable_TransposeConv1D
542
+ ):
543
+ module.weight_norm()
544
+ return module
545
+ else:
546
+ return weight_norm(module, name, dim)
547
+
548
+
549
+ def remove_weight_norm_modules(module, name="weight"):
550
+ if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
551
+ module, Depthwise_Separable_TransposeConv1D
552
+ ):
553
+ module.remove_weight_norm()
554
+ else:
555
+ remove_weight_norm(module, name)
556
+
557
+
558
+ class FFT(nn.Module):
559
+ def __init__(
560
+ self,
561
+ hidden_channels,
562
+ filter_channels,
563
+ n_heads,
564
+ n_layers=1,
565
+ kernel_size=1,
566
+ p_dropout=0.0,
567
+ proximal_bias=False,
568
+ proximal_init=True,
569
+ isflow=False,
570
+ **kwargs
571
+ ):
572
+ super().__init__()
573
+ self.hidden_channels = hidden_channels
574
+ self.filter_channels = filter_channels
575
+ self.n_heads = n_heads
576
+ self.n_layers = n_layers
577
+ self.kernel_size = kernel_size
578
+ self.p_dropout = p_dropout
579
+ self.proximal_bias = proximal_bias
580
+ self.proximal_init = proximal_init
581
+ if isflow:
582
+ cond_layer = torch.nn.Conv1d(
583
+ kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
584
+ )
585
+ self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
586
+ self.cond_layer = weight_norm_modules(cond_layer, name="weight")
587
+ self.gin_channels = kwargs["gin_channels"]
588
+ self.drop = nn.Dropout(p_dropout)
589
+ self.self_attn_layers = nn.ModuleList()
590
+ self.norm_layers_0 = nn.ModuleList()
591
+ self.ffn_layers = nn.ModuleList()
592
+ self.norm_layers_1 = nn.ModuleList()
593
+ for i in range(self.n_layers):
594
+ self.self_attn_layers.append(
595
+ MultiHeadAttention(
596
+ hidden_channels,
597
+ hidden_channels,
598
+ n_heads,
599
+ p_dropout=p_dropout,
600
+ proximal_bias=proximal_bias,
601
+ proximal_init=proximal_init,
602
+ )
603
+ )
604
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
605
+ self.ffn_layers.append(
606
+ FFN(
607
+ hidden_channels,
608
+ hidden_channels,
609
+ filter_channels,
610
+ kernel_size,
611
+ p_dropout=p_dropout,
612
+ causal=True,
613
+ )
614
+ )
615
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
616
+
617
+ def forward(self, x, x_mask, g=None):
618
+ """
619
+ x: decoder input
620
+ h: encoder output
621
+ """
622
+ if g is not None:
623
+ g = self.cond_layer(g)
624
+
625
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
626
+ device=x.device, dtype=x.dtype
627
+ )
628
+ x = x * x_mask
629
+ for i in range(self.n_layers):
630
+ if g is not None:
631
+ x = self.cond_pre(x)
632
+ cond_offset = i * 2 * self.hidden_channels
633
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
634
+ x = commons.fused_add_tanh_sigmoid_multiply(
635
+ x, g_l, torch.IntTensor([self.hidden_channels])
636
+ )
637
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
638
+ y = self.drop(y)
639
+ x = self.norm_layers_0[i](x + y)
640
+
641
+ y = self.ffn_layers[i](x, x_mask)
642
+ y = self.drop(y)
643
+ x = self.norm_layers_1[i](x + y)
644
+ x = x * x_mask
645
+ return x
646
+
647
+
648
+ class TransformerCouplingLayer(nn.Module):
649
+ def __init__(
650
+ self,
651
+ channels,
652
+ hidden_channels,
653
+ kernel_size,
654
+ n_layers,
655
+ n_heads,
656
+ p_dropout=0,
657
+ filter_channels=0,
658
+ mean_only=False,
659
+ wn_sharing_parameter=None,
660
+ gin_channels=0,
661
+ ):
662
+ assert channels % 2 == 0, "channels should be divisible by 2"
663
+ super().__init__()
664
+ self.channels = channels
665
+ self.hidden_channels = hidden_channels
666
+ self.kernel_size = kernel_size
667
+ self.n_layers = n_layers
668
+ self.half_channels = channels // 2
669
+ self.mean_only = mean_only
670
+
671
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
672
+ self.enc = (
673
+ Encoder(
674
+ hidden_channels,
675
+ filter_channels,
676
+ n_heads,
677
+ n_layers,
678
+ kernel_size,
679
+ p_dropout,
680
+ isflow=True,
681
+ gin_channels=gin_channels,
682
+ )
683
+ if wn_sharing_parameter is None
684
+ else wn_sharing_parameter
685
+ )
686
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
687
+ self.post.weight.data.zero_()
688
+ self.post.bias.data.zero_()
689
+
690
+ def forward(self, x, x_mask, g=None, reverse=False):
691
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
692
+ h = self.pre(x0) * x_mask
693
+ h = self.enc(h, x_mask, g=g)
694
+ stats = self.post(h) * x_mask
695
+ if not self.mean_only:
696
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
697
+ else:
698
+ m = stats
699
+ logs = torch.zeros_like(m)
700
+
701
+ if not reverse:
702
+ x1 = m + x1 * torch.exp(logs) * x_mask
703
+ x = torch.cat([x0, x1], 1)
704
+ logdet = torch.sum(logs, [1, 2])
705
+ return x, logdet
706
+ else:
707
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
708
+ x = torch.cat([x0, x1], 1)
709
+ return x
module/commons.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch.nn import functional as F
4
+
5
+
6
+ def init_weights(m, mean=0.0, std=0.01):
7
+ classname = m.__class__.__name__
8
+ if classname.find("Conv") != -1:
9
+ m.weight.data.normal_(mean, std)
10
+
11
+
12
+ def get_padding(kernel_size, dilation=1):
13
+ return int((kernel_size * dilation - dilation) / 2)
14
+
15
+
16
+ def convert_pad_shape(pad_shape):
17
+ l = pad_shape[::-1]
18
+ pad_shape = [item for sublist in l for item in sublist]
19
+ return pad_shape
20
+
21
+
22
+ def intersperse(lst, item):
23
+ result = [item] * (len(lst) * 2 + 1)
24
+ result[1::2] = lst
25
+ return result
26
+
27
+
28
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
29
+ """KL(P||Q)"""
30
+ kl = (logs_q - logs_p) - 0.5
31
+ kl += (
32
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
33
+ )
34
+ return kl
35
+
36
+
37
+ def rand_gumbel(shape):
38
+ """Sample from the Gumbel distribution, protect from overflows."""
39
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
40
+ return -torch.log(-torch.log(uniform_samples))
41
+
42
+
43
+ def rand_gumbel_like(x):
44
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
45
+ return g
46
+
47
+
48
+ def slice_segments(x, ids_str, segment_size=4):
49
+ ret = torch.zeros_like(x[:, :, :segment_size])
50
+ for i in range(x.size(0)):
51
+ idx_str = ids_str[i]
52
+ idx_end = idx_str + segment_size
53
+ ret[i] = x[i, :, idx_str:idx_end]
54
+ return ret
55
+
56
+
57
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
58
+ b, d, t = x.size()
59
+ if x_lengths is None:
60
+ x_lengths = t
61
+ ids_str_max = x_lengths - segment_size + 1
62
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
63
+ ret = slice_segments(x, ids_str, segment_size)
64
+ return ret, ids_str
65
+
66
+
67
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
68
+ position = torch.arange(length, dtype=torch.float)
69
+ num_timescales = channels // 2
70
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
71
+ num_timescales - 1
72
+ )
73
+ inv_timescales = min_timescale * torch.exp(
74
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
75
+ )
76
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
77
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
78
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
79
+ signal = signal.view(1, channels, length)
80
+ return signal
81
+
82
+
83
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
84
+ b, channels, length = x.size()
85
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
86
+ return x + signal.to(dtype=x.dtype, device=x.device)
87
+
88
+
89
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
90
+ b, channels, length = x.size()
91
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
92
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
93
+
94
+
95
+ def subsequent_mask(length):
96
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
97
+ return mask
98
+
99
+
100
+ @torch.jit.script
101
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
102
+ n_channels_int = n_channels[0]
103
+ in_act = input_a + input_b
104
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
105
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
106
+ acts = t_act * s_act
107
+ return acts
108
+
109
+
110
+ def convert_pad_shape(pad_shape):
111
+ l = pad_shape[::-1]
112
+ pad_shape = [item for sublist in l for item in sublist]
113
+ return pad_shape
114
+
115
+
116
+ def shift_1d(x):
117
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
118
+ return x
119
+
120
+
121
+ def sequence_mask(length, max_length=None):
122
+ if max_length is None:
123
+ max_length = length.max()
124
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
125
+ return x.unsqueeze(0) < length.unsqueeze(1)
126
+
127
+
128
+ def generate_path(duration, mask):
129
+ """
130
+ duration: [b, 1, t_x]
131
+ mask: [b, 1, t_y, t_x]
132
+ """
133
+ device = duration.device
134
+
135
+ b, _, t_y, t_x = mask.shape
136
+ cum_duration = torch.cumsum(duration, -1)
137
+
138
+ cum_duration_flat = cum_duration.view(b * t_x)
139
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
140
+ path = path.view(b, t_x, t_y)
141
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
142
+ path = path.unsqueeze(1).transpose(2, 3) * mask
143
+ return path
144
+
145
+
146
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
147
+ if isinstance(parameters, torch.Tensor):
148
+ parameters = [parameters]
149
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
150
+ norm_type = float(norm_type)
151
+ if clip_value is not None:
152
+ clip_value = float(clip_value)
153
+
154
+ total_norm = 0
155
+ for p in parameters:
156
+ param_norm = p.grad.data.norm(norm_type)
157
+ total_norm += param_norm.item() ** norm_type
158
+ if clip_value is not None:
159
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
160
+ total_norm = total_norm ** (1.0 / norm_type)
161
+ return total_norm
162
+
163
+
164
+ def squeeze(x, x_mask=None, n_sqz=2):
165
+ b, c, t = x.size()
166
+
167
+ t = (t // n_sqz) * n_sqz
168
+ x = x[:, :, :t]
169
+ x_sqz = x.view(b, c, t // n_sqz, n_sqz)
170
+ x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)
171
+
172
+ if x_mask is not None:
173
+ x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz]
174
+ else:
175
+ x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
176
+ return x_sqz * x_mask, x_mask
177
+
178
+
179
+ def unsqueeze(x, x_mask=None, n_sqz=2):
180
+ b, c, t = x.size()
181
+
182
+ x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
183
+ x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)
184
+
185
+ if x_mask is not None:
186
+ x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
187
+ else:
188
+ x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
189
+ return x_unsqz * x_mask, x_mask
module/core_vq.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # This implementation is inspired from
8
+ # https://github.com/lucidrains/vector-quantize-pytorch
9
+ # which is released under MIT License. Hereafter, the original license:
10
+ # MIT License
11
+ #
12
+ # Copyright (c) 2020 Phil Wang
13
+ #
14
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ # of this software and associated documentation files (the "Software"), to deal
16
+ # in the Software without restriction, including without limitation the rights
17
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ # copies of the Software, and to permit persons to whom the Software is
19
+ # furnished to do so, subject to the following conditions:
20
+ #
21
+ # The above copyright notice and this permission notice shall be included in all
22
+ # copies or substantial portions of the Software.
23
+ #
24
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30
+ # SOFTWARE.
31
+
32
+ """Core vector quantization implementation."""
33
+ import typing as tp
34
+
35
+ from einops import rearrange, repeat
36
+ import torch
37
+ from torch import nn
38
+ import torch.nn.functional as F
39
+ from tqdm import tqdm
40
+
41
+
42
+ def default(val: tp.Any, d: tp.Any) -> tp.Any:
43
+ return val if val is not None else d
44
+
45
+
46
+ def ema_inplace(moving_avg, new, decay: float):
47
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
48
+
49
+
50
+ def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
51
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
52
+
53
+
54
+ def uniform_init(*shape: int):
55
+ t = torch.empty(shape)
56
+ nn.init.kaiming_uniform_(t)
57
+ return t
58
+
59
+
60
+ def sample_vectors(samples, num: int):
61
+ num_samples, device = samples.shape[0], samples.device
62
+
63
+ if num_samples >= num:
64
+ indices = torch.randperm(num_samples, device=device)[:num]
65
+ else:
66
+ indices = torch.randint(0, num_samples, (num,), device=device)
67
+
68
+ return samples[indices]
69
+
70
+
71
+ def kmeans(samples, num_clusters: int, num_iters: int = 10):
72
+ dim, dtype = samples.shape[-1], samples.dtype
73
+ max_kmeans_samples = 500
74
+ samples = samples[:max_kmeans_samples, :]
75
+ means = sample_vectors(samples, num_clusters)
76
+
77
+ print("kmeans start ... ")
78
+ for _ in tqdm(range(num_iters)):
79
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
80
+ dists = -(diffs**2).sum(dim=-1)
81
+
82
+ buckets = dists.max(dim=-1).indices
83
+ bins = torch.bincount(buckets, minlength=num_clusters)
84
+ zero_mask = bins == 0
85
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
86
+
87
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
88
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
89
+ new_means = new_means / bins_min_clamped[..., None]
90
+
91
+ means = torch.where(zero_mask[..., None], means, new_means)
92
+
93
+ return means, bins
94
+
95
+
96
+ class EuclideanCodebook(nn.Module):
97
+ """Codebook with Euclidean distance.
98
+ Args:
99
+ dim (int): Dimension.
100
+ codebook_size (int): Codebook size.
101
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
102
+ If set to true, run the k-means algorithm on the first training batch and use
103
+ the learned centroids as initialization.
104
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
105
+ decay (float): Decay for exponential moving average over the codebooks.
106
+ epsilon (float): Epsilon value for numerical stability.
107
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
108
+ that have an exponential moving average cluster size less than the specified threshold with
109
+ randomly selected vector from the current batch.
110
+ """
111
+
112
+ def __init__(
113
+ self,
114
+ dim: int,
115
+ codebook_size: int,
116
+ kmeans_init: int = False,
117
+ kmeans_iters: int = 10,
118
+ decay: float = 0.99,
119
+ epsilon: float = 1e-5,
120
+ threshold_ema_dead_code: int = 2,
121
+ ):
122
+ super().__init__()
123
+ self.decay = decay
124
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
125
+ uniform_init if not kmeans_init else torch.zeros
126
+ )
127
+ embed = init_fn(codebook_size, dim)
128
+
129
+ self.codebook_size = codebook_size
130
+
131
+ self.kmeans_iters = kmeans_iters
132
+ self.epsilon = epsilon
133
+ self.threshold_ema_dead_code = threshold_ema_dead_code
134
+
135
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
136
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
137
+ self.register_buffer("embed", embed)
138
+ self.register_buffer("embed_avg", embed.clone())
139
+
140
+ @torch.jit.ignore
141
+ def init_embed_(self, data):
142
+ if self.inited:
143
+ return
144
+
145
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
146
+ self.embed.data.copy_(embed)
147
+ self.embed_avg.data.copy_(embed.clone())
148
+ self.cluster_size.data.copy_(cluster_size)
149
+ self.inited.data.copy_(torch.Tensor([True]))
150
+ # Make sure all buffers across workers are in sync after initialization
151
+ # broadcast_tensors(self.buffers())
152
+
153
+ def replace_(self, samples, mask):
154
+ modified_codebook = torch.where(
155
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
156
+ )
157
+ self.embed.data.copy_(modified_codebook)
158
+
159
+ def expire_codes_(self, batch_samples):
160
+ if self.threshold_ema_dead_code == 0:
161
+ return
162
+
163
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
164
+ if not torch.any(expired_codes):
165
+ return
166
+
167
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
168
+ self.replace_(batch_samples, mask=expired_codes)
169
+ # broadcast_tensors(self.buffers())
170
+
171
+ def preprocess(self, x):
172
+ x = rearrange(x, "... d -> (...) d")
173
+ return x
174
+
175
+ def quantize(self, x):
176
+ embed = self.embed.t()
177
+ dist = -(
178
+ x.pow(2).sum(1, keepdim=True)
179
+ - 2 * x @ embed
180
+ + embed.pow(2).sum(0, keepdim=True)
181
+ )
182
+ embed_ind = dist.max(dim=-1).indices
183
+ return embed_ind
184
+
185
+ def postprocess_emb(self, embed_ind, shape):
186
+ return embed_ind.view(*shape[:-1])
187
+
188
+ def dequantize(self, embed_ind):
189
+ quantize = F.embedding(embed_ind, self.embed)
190
+ return quantize
191
+
192
+ def encode(self, x):
193
+ shape = x.shape
194
+ # pre-process
195
+ x = self.preprocess(x)
196
+ # quantize
197
+ embed_ind = self.quantize(x)
198
+ # post-process
199
+ embed_ind = self.postprocess_emb(embed_ind, shape)
200
+ return embed_ind
201
+
202
+ def decode(self, embed_ind):
203
+ quantize = self.dequantize(embed_ind)
204
+ return quantize
205
+
206
+ def forward(self, x):
207
+ shape, dtype = x.shape, x.dtype
208
+ x = self.preprocess(x)
209
+
210
+ self.init_embed_(x)
211
+
212
+ embed_ind = self.quantize(x)
213
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
214
+ embed_ind = self.postprocess_emb(embed_ind, shape)
215
+ quantize = self.dequantize(embed_ind)
216
+
217
+ if self.training:
218
+ # We do the expiry of code at that point as buffers are in sync
219
+ # and all the workers will take the same decision.
220
+ self.expire_codes_(x)
221
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
222
+ embed_sum = x.t() @ embed_onehot
223
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
224
+ cluster_size = (
225
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
226
+ * self.cluster_size.sum()
227
+ )
228
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
229
+ self.embed.data.copy_(embed_normalized)
230
+
231
+ return quantize, embed_ind
232
+
233
+
234
+ class VectorQuantization(nn.Module):
235
+ """Vector quantization implementation.
236
+ Currently supports only euclidean distance.
237
+ Args:
238
+ dim (int): Dimension
239
+ codebook_size (int): Codebook size
240
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
241
+ decay (float): Decay for exponential moving average over the codebooks.
242
+ epsilon (float): Epsilon value for numerical stability.
243
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
244
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
245
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
246
+ that have an exponential moving average cluster size less than the specified threshold with
247
+ randomly selected vector from the current batch.
248
+ commitment_weight (float): Weight for commitment loss.
249
+ """
250
+
251
+ def __init__(
252
+ self,
253
+ dim: int,
254
+ codebook_size: int,
255
+ codebook_dim: tp.Optional[int] = None,
256
+ decay: float = 0.99,
257
+ epsilon: float = 1e-5,
258
+ kmeans_init: bool = True,
259
+ kmeans_iters: int = 50,
260
+ threshold_ema_dead_code: int = 2,
261
+ commitment_weight: float = 1.0,
262
+ ):
263
+ super().__init__()
264
+ _codebook_dim: int = default(codebook_dim, dim)
265
+
266
+ requires_projection = _codebook_dim != dim
267
+ self.project_in = (
268
+ nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
269
+ )
270
+ self.project_out = (
271
+ nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
272
+ )
273
+
274
+ self.epsilon = epsilon
275
+ self.commitment_weight = commitment_weight
276
+
277
+ self._codebook = EuclideanCodebook(
278
+ dim=_codebook_dim,
279
+ codebook_size=codebook_size,
280
+ kmeans_init=kmeans_init,
281
+ kmeans_iters=kmeans_iters,
282
+ decay=decay,
283
+ epsilon=epsilon,
284
+ threshold_ema_dead_code=threshold_ema_dead_code,
285
+ )
286
+ self.codebook_size = codebook_size
287
+
288
+ @property
289
+ def codebook(self):
290
+ return self._codebook.embed
291
+
292
+ def encode(self, x):
293
+ x = rearrange(x, "b d n -> b n d")
294
+ x = self.project_in(x)
295
+ embed_in = self._codebook.encode(x)
296
+ return embed_in
297
+
298
+ def decode(self, embed_ind):
299
+ quantize = self._codebook.decode(embed_ind)
300
+ quantize = self.project_out(quantize)
301
+ quantize = rearrange(quantize, "b n d -> b d n")
302
+ return quantize
303
+
304
+ def forward(self, x):
305
+ device = x.device
306
+ x = rearrange(x, "b d n -> b n d")
307
+ x = self.project_in(x)
308
+
309
+ quantize, embed_ind = self._codebook(x)
310
+
311
+ if self.training:
312
+ quantize = x + (quantize - x).detach()
313
+
314
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
315
+
316
+ if self.training:
317
+ if self.commitment_weight > 0:
318
+ commit_loss = F.mse_loss(quantize.detach(), x)
319
+ loss = loss + commit_loss * self.commitment_weight
320
+
321
+ quantize = self.project_out(quantize)
322
+ quantize = rearrange(quantize, "b n d -> b d n")
323
+ return quantize, embed_ind, loss
324
+
325
+
326
+ class ResidualVectorQuantization(nn.Module):
327
+ """Residual vector quantization implementation.
328
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
329
+ """
330
+
331
+ def __init__(self, *, num_quantizers, **kwargs):
332
+ super().__init__()
333
+ self.layers = nn.ModuleList(
334
+ [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
335
+ )
336
+
337
+ def forward(
338
+ self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
339
+ ):
340
+ quantized_out = 0.0
341
+ residual = x
342
+
343
+ all_losses = []
344
+ all_indices = []
345
+ out_quantized = []
346
+
347
+ n_q = n_q or len(self.layers)
348
+
349
+ for i, layer in enumerate(self.layers[:n_q]):
350
+ quantized, indices, loss = layer(residual)
351
+ residual = residual - quantized
352
+ quantized_out = quantized_out + quantized
353
+
354
+ all_indices.append(indices)
355
+ all_losses.append(loss)
356
+ if layers and i in layers:
357
+ out_quantized.append(quantized)
358
+
359
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
360
+ return quantized_out, out_indices, out_losses, out_quantized
361
+
362
+ def encode(
363
+ self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
364
+ ) -> torch.Tensor:
365
+ residual = x
366
+ all_indices = []
367
+ n_q = n_q or len(self.layers)
368
+ st = st or 0
369
+ for layer in self.layers[st:n_q]:
370
+ indices = layer.encode(residual)
371
+ quantized = layer.decode(indices)
372
+ residual = residual - quantized
373
+ all_indices.append(indices)
374
+ out_indices = torch.stack(all_indices)
375
+ return out_indices
376
+
377
+ def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
378
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
379
+ for i, indices in enumerate(q_indices):
380
+ layer = self.layers[st + i]
381
+ quantized = layer.decode(indices)
382
+ quantized_out = quantized_out + quantized
383
+ return quantized_out
module/data_utils.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time, logging
2
+ import os
3
+ import random, traceback
4
+ import numpy as np
5
+ import torch
6
+ import torch.utils.data
7
+ from tqdm import tqdm
8
+
9
+ from module import commons
10
+ from module.mel_processing import spectrogram_torch
11
+ from text import cleaned_text_to_sequence
12
+ from utils import load_wav_to_torch, load_filepaths_and_text
13
+ import torch.nn.functional as F
14
+ from functools import lru_cache
15
+ import torch
16
+ import requests
17
+ from scipy.io import wavfile
18
+ from io import BytesIO
19
+
20
+ # from config import exp_dir
21
+ from my_utils import load_audio
22
+
23
+
24
+ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
25
+ """
26
+ 1) loads audio, speaker_id, text pairs
27
+ 2) normalizes text and converts them to sequences of integers
28
+ 3) computes spectrograms from audio files.
29
+ """
30
+
31
+ def __init__(self, hparams, val=False):
32
+ exp_dir = hparams.exp_dir
33
+ self.path2 = "%s/2-name2text.txt" % exp_dir
34
+ self.path4 = "%s/4-cnhubert" % exp_dir
35
+ self.path5 = "%s/5-wav32k" % exp_dir
36
+ assert os.path.exists(self.path2)
37
+ assert os.path.exists(self.path4)
38
+ assert os.path.exists(self.path5)
39
+ names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # 去除.pt后缀
40
+ names5 = set(os.listdir(self.path5))
41
+ self.phoneme_data = {}
42
+ with open(self.path2, "r", encoding="utf8") as f:
43
+ lines = f.read().strip("\n").split("\n")
44
+
45
+ for line in lines:
46
+ tmp = line.split("\t")
47
+ if len(tmp) != 4:
48
+ continue
49
+ self.phoneme_data[tmp[0]] = [tmp[1]]
50
+
51
+ self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
52
+ tmp = self.audiopaths_sid_text
53
+ leng = len(tmp)
54
+ min_num = 100
55
+ if leng < min_num:
56
+ self.audiopaths_sid_text = []
57
+ for _ in range(max(2, int(min_num / leng))):
58
+ self.audiopaths_sid_text += tmp
59
+ self.max_wav_value = hparams.max_wav_value
60
+ self.sampling_rate = hparams.sampling_rate
61
+ self.filter_length = hparams.filter_length
62
+ self.hop_length = hparams.hop_length
63
+ self.win_length = hparams.win_length
64
+ self.sampling_rate = hparams.sampling_rate
65
+ self.val = val
66
+
67
+ random.seed(1234)
68
+ random.shuffle(self.audiopaths_sid_text)
69
+
70
+ print("phoneme_data_len:", len(self.phoneme_data.keys()))
71
+ print("wav_data_len:", len(self.audiopaths_sid_text))
72
+
73
+ audiopaths_sid_text_new = []
74
+ lengths = []
75
+ skipped_phone = 0
76
+ skipped_dur = 0
77
+ for audiopath in tqdm(self.audiopaths_sid_text):
78
+ try:
79
+ phoneme = self.phoneme_data[audiopath][0]
80
+ phoneme = phoneme.split(" ")
81
+ phoneme_ids = cleaned_text_to_sequence(phoneme)
82
+ except Exception:
83
+ print(f"{audiopath} not in self.phoneme_data !")
84
+ skipped_phone += 1
85
+ continue
86
+ size = os.path.getsize("%s/%s" % (self.path5, audiopath))
87
+ duration = size / self.sampling_rate / 2
88
+ if 54 > duration > 0.6 or self.val:
89
+ audiopaths_sid_text_new.append([audiopath, phoneme_ids])
90
+ lengths.append(size // (2 * self.hop_length))
91
+ else:
92
+ skipped_dur += 1
93
+ continue
94
+ print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
95
+ print("total left: ", len(audiopaths_sid_text_new))
96
+ assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size,这里todo
97
+ self.audiopaths_sid_text = audiopaths_sid_text_new
98
+ self.lengths = lengths
99
+
100
+ def get_audio_text_speaker_pair(self, audiopath_sid_text):
101
+ audiopath, phoneme_ids = audiopath_sid_text
102
+ text = torch.FloatTensor(phoneme_ids)
103
+ try:
104
+ spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
105
+ with torch.no_grad():
106
+ ssl = torch.load(
107
+ "%s/%s.pt" % (self.path4, audiopath), map_location="cpu"
108
+ )
109
+ if ssl.shape[-1] != spec.shape[-1]:
110
+ typee = ssl.dtype
111
+ ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
112
+ ssl.requires_grad = False
113
+ except:
114
+ traceback.print_exc()
115
+ spec = torch.zeros(1025, 100)
116
+ wav = torch.zeros(1, 100 * self.hop_length)
117
+ ssl = torch.zeros(1, 768, 100)
118
+ text = text[-1:]
119
+ print("load audio or ssl error!!!!!!", audiopath)
120
+ # print(ssl.requires_grad,spec.requires_grad,wav.requires_grad,text.requires_grad)
121
+ return (ssl, spec, wav, text)
122
+
123
+ def get_audio(self, filename):
124
+ audio_array = load_audio(
125
+ filename, self.sampling_rate
126
+ ) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768
127
+ # print(filename,audio_array.max(),audio_array.min(),audio_array.mean())
128
+ audio = torch.FloatTensor(audio_array) # /32768
129
+ audio_norm = audio
130
+ audio_norm = audio_norm.unsqueeze(0)
131
+ spec = spectrogram_torch(
132
+ audio_norm,
133
+ self.filter_length,
134
+ self.sampling_rate,
135
+ self.hop_length,
136
+ self.win_length,
137
+ center=False,
138
+ )
139
+ spec = torch.squeeze(spec, 0)
140
+ return spec, audio_norm
141
+
142
+ def get_sid(self, sid):
143
+ sid = torch.LongTensor([int(sid)])
144
+ return sid
145
+
146
+ def __getitem__(self, index):
147
+ # with torch.no_grad():
148
+ return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
149
+
150
+ def __len__(self):
151
+ return len(self.audiopaths_sid_text)
152
+
153
+ def random_slice(self, ssl, wav, mel):
154
+ assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
155
+ "first",
156
+ ssl.shape,
157
+ wav.shape,
158
+ )
159
+
160
+ len_mel = mel.shape[1]
161
+ if self.val:
162
+ reference_mel = mel[:, : len_mel // 3]
163
+ return reference_mel, ssl, wav, mel
164
+ dir = random.randint(0, 1)
165
+ sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2))
166
+
167
+ if dir == 0:
168
+ reference_mel = mel[:, :sep_point]
169
+ ssl = ssl[:, :, sep_point:]
170
+ wav2 = wav[:, sep_point * self.hop_length :]
171
+ mel = mel[:, sep_point:]
172
+ else:
173
+ reference_mel = mel[:, sep_point:]
174
+ ssl = ssl[:, :, :sep_point]
175
+ wav2 = wav[:, : sep_point * self.hop_length]
176
+ mel = mel[:, :sep_point]
177
+
178
+ assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
179
+ ssl.shape,
180
+ wav.shape,
181
+ wav2.shape,
182
+ mel.shape,
183
+ sep_point,
184
+ self.hop_length,
185
+ sep_point * self.hop_length,
186
+ dir,
187
+ )
188
+ return reference_mel, ssl, wav2, mel
189
+
190
+
191
+ class TextAudioSpeakerCollate:
192
+ """Zero-pads model inputs and targets"""
193
+
194
+ def __init__(self, return_ids=False):
195
+ self.return_ids = return_ids
196
+
197
+ def __call__(self, batch):
198
+ """Collate's training batch from normalized text, audio and speaker identities
199
+ PARAMS
200
+ ------
201
+ batch: [text_normalized, spec_normalized, wav_normalized, sid]
202
+ """
203
+ # Right zero-pad all one-hot text sequences to max input length
204
+ _, ids_sorted_decreasing = torch.sort(
205
+ torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True
206
+ )
207
+
208
+ max_ssl_len = max([x[0].size(2) for x in batch])
209
+ max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
210
+ max_spec_len = max([x[1].size(1) for x in batch])
211
+ max_spec_len = int(2 * ((max_spec_len // 2) + 1))
212
+ max_wav_len = max([x[2].size(1) for x in batch])
213
+ max_text_len = max([x[3].size(0) for x in batch])
214
+
215
+ ssl_lengths = torch.LongTensor(len(batch))
216
+ spec_lengths = torch.LongTensor(len(batch))
217
+ wav_lengths = torch.LongTensor(len(batch))
218
+ text_lengths = torch.LongTensor(len(batch))
219
+
220
+ spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
221
+ wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
222
+ ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
223
+ text_padded = torch.LongTensor(len(batch), max_text_len)
224
+
225
+ spec_padded.zero_()
226
+ wav_padded.zero_()
227
+ ssl_padded.zero_()
228
+ text_padded.zero_()
229
+
230
+ for i in range(len(ids_sorted_decreasing)):
231
+ row = batch[ids_sorted_decreasing[i]]
232
+
233
+ ssl = row[0]
234
+ ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
235
+ ssl_lengths[i] = ssl.size(2)
236
+
237
+ spec = row[1]
238
+ spec_padded[i, :, : spec.size(1)] = spec
239
+ spec_lengths[i] = spec.size(1)
240
+
241
+ wav = row[2]
242
+ wav_padded[i, :, : wav.size(1)] = wav
243
+ wav_lengths[i] = wav.size(1)
244
+
245
+ text = row[3]
246
+ text_padded[i, : text.size(0)] = text
247
+ text_lengths[i] = text.size(0)
248
+
249
+ return (
250
+ ssl_padded,
251
+ ssl_lengths,
252
+ spec_padded,
253
+ spec_lengths,
254
+ wav_padded,
255
+ wav_lengths,
256
+ text_padded,
257
+ text_lengths,
258
+ )
259
+
260
+
261
+ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
262
+ """
263
+ Maintain similar input lengths in a batch.
264
+ Length groups are specified by boundaries.
265
+ Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}.
266
+
267
+ It removes samples which are not included in the boundaries.
268
+ Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
269
+ """
270
+
271
+ def __init__(
272
+ self,
273
+ dataset,
274
+ batch_size,
275
+ boundaries,
276
+ num_replicas=None,
277
+ rank=None,
278
+ shuffle=True,
279
+ ):
280
+ super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
281
+ self.lengths = dataset.lengths
282
+ # print(233333333333333,self.lengths,dir(dataset))
283
+ self.batch_size = batch_size
284
+ self.boundaries = boundaries
285
+
286
+ self.buckets, self.num_samples_per_bucket = self._create_buckets()
287
+ self.total_size = sum(self.num_samples_per_bucket)
288
+ self.num_samples = self.total_size // self.num_replicas
289
+
290
+ def _create_buckets(self):
291
+ buckets = [[] for _ in range(len(self.boundaries) - 1)]
292
+ for i in range(len(self.lengths)):
293
+ length = self.lengths[i]
294
+ idx_bucket = self._bisect(length)
295
+ if idx_bucket != -1:
296
+ buckets[idx_bucket].append(i)
297
+
298
+ for i in range(len(buckets) - 1, 0, -1):
299
+ # for i in range(len(buckets) - 1, -1, -1):
300
+ if len(buckets[i]) == 0:
301
+ buckets.pop(i)
302
+ self.boundaries.pop(i + 1)
303
+
304
+ num_samples_per_bucket = []
305
+ for i in range(len(buckets)):
306
+ len_bucket = len(buckets[i])
307
+ total_batch_size = self.num_replicas * self.batch_size
308
+ rem = (
309
+ total_batch_size - (len_bucket % total_batch_size)
310
+ ) % total_batch_size
311
+ num_samples_per_bucket.append(len_bucket + rem)
312
+ return buckets, num_samples_per_bucket
313
+
314
+ def __iter__(self):
315
+ # deterministically shuffle based on epoch
316
+ g = torch.Generator()
317
+ g.manual_seed(self.epoch)
318
+
319
+ indices = []
320
+ if self.shuffle:
321
+ for bucket in self.buckets:
322
+ indices.append(torch.randperm(len(bucket), generator=g).tolist())
323
+ else:
324
+ for bucket in self.buckets:
325
+ indices.append(list(range(len(bucket))))
326
+
327
+ batches = []
328
+ for i in range(len(self.buckets)):
329
+ bucket = self.buckets[i]
330
+ len_bucket = len(bucket)
331
+ ids_bucket = indices[i]
332
+ num_samples_bucket = self.num_samples_per_bucket[i]
333
+
334
+ # add extra samples to make it evenly divisible
335
+ rem = num_samples_bucket - len_bucket
336
+ ids_bucket = (
337
+ ids_bucket
338
+ + ids_bucket * (rem // len_bucket)
339
+ + ids_bucket[: (rem % len_bucket)]
340
+ )
341
+
342
+ # subsample
343
+ ids_bucket = ids_bucket[self.rank :: self.num_replicas]
344
+
345
+ # batching
346
+ for j in range(len(ids_bucket) // self.batch_size):
347
+ batch = [
348
+ bucket[idx]
349
+ for idx in ids_bucket[
350
+ j * self.batch_size : (j + 1) * self.batch_size
351
+ ]
352
+ ]
353
+ batches.append(batch)
354
+
355
+ if self.shuffle:
356
+ batch_ids = torch.randperm(len(batches), generator=g).tolist()
357
+ batches = [batches[i] for i in batch_ids]
358
+ self.batches = batches
359
+
360
+ assert len(self.batches) * self.batch_size == self.num_samples
361
+ return iter(self.batches)
362
+
363
+ def _bisect(self, x, lo=0, hi=None):
364
+ if hi is None:
365
+ hi = len(self.boundaries) - 1
366
+
367
+ if hi > lo:
368
+ mid = (hi + lo) // 2
369
+ if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
370
+ return mid
371
+ elif x <= self.boundaries[mid]:
372
+ return self._bisect(x, lo, mid)
373
+ else:
374
+ return self._bisect(x, mid + 1, hi)
375
+ else:
376
+ return -1
377
+
378
+ def __len__(self):
379
+ return self.num_samples // self.batch_size
module/losses.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+
7
+ def feature_loss(fmap_r, fmap_g):
8
+ loss = 0
9
+ for dr, dg in zip(fmap_r, fmap_g):
10
+ for rl, gl in zip(dr, dg):
11
+ rl = rl.float().detach()
12
+ gl = gl.float()
13
+ loss += torch.mean(torch.abs(rl - gl))
14
+
15
+ return loss * 2
16
+
17
+
18
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
19
+ loss = 0
20
+ r_losses = []
21
+ g_losses = []
22
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
23
+ dr = dr.float()
24
+ dg = dg.float()
25
+ r_loss = torch.mean((1 - dr) ** 2)
26
+ g_loss = torch.mean(dg**2)
27
+ loss += r_loss + g_loss
28
+ r_losses.append(r_loss.item())
29
+ g_losses.append(g_loss.item())
30
+
31
+ return loss, r_losses, g_losses
32
+
33
+
34
+ def generator_loss(disc_outputs):
35
+ loss = 0
36
+ gen_losses = []
37
+ for dg in disc_outputs:
38
+ dg = dg.float()
39
+ l = torch.mean((1 - dg) ** 2)
40
+ gen_losses.append(l)
41
+ loss += l
42
+
43
+ return loss, gen_losses
44
+
45
+
46
+ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
47
+ """
48
+ z_p, logs_q: [b, h, t_t]
49
+ m_p, logs_p: [b, h, t_t]
50
+ """
51
+ z_p = z_p.float()
52
+ logs_q = logs_q.float()
53
+ m_p = m_p.float()
54
+ logs_p = logs_p.float()
55
+ z_mask = z_mask.float()
56
+
57
+ kl = logs_p - logs_q - 0.5
58
+ kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
59
+ kl = torch.sum(kl * z_mask)
60
+ l = kl / torch.sum(z_mask)
61
+ return l
62
+
63
+
64
+ def mle_loss(z, m, logs, logdet, mask):
65
+ l = torch.sum(logs) + 0.5 * torch.sum(
66
+ torch.exp(-2 * logs) * ((z - m) ** 2)
67
+ ) # neg normal likelihood w/o the constant term
68
+ l = l - torch.sum(logdet) # log jacobian determinant
69
+ l = l / torch.sum(
70
+ torch.ones_like(z) * mask
71
+ ) # averaging across batch, channel and time axes
72
+ l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term
73
+ return l
module/mel_processing.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import random
4
+ import torch
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+ import torch.utils.data
8
+ import numpy as np
9
+ import librosa
10
+ import librosa.util as librosa_util
11
+ from librosa.util import normalize, pad_center, tiny
12
+ from scipy.signal import get_window
13
+ from scipy.io.wavfile import read
14
+ from librosa.filters import mel as librosa_mel_fn
15
+
16
+ MAX_WAV_VALUE = 32768.0
17
+
18
+
19
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
20
+ """
21
+ PARAMS
22
+ ------
23
+ C: compression factor
24
+ """
25
+ return torch.log(torch.clamp(x, min=clip_val) * C)
26
+
27
+
28
+ def dynamic_range_decompression_torch(x, C=1):
29
+ """
30
+ PARAMS
31
+ ------
32
+ C: compression factor used to compress
33
+ """
34
+ return torch.exp(x) / C
35
+
36
+
37
+ def spectral_normalize_torch(magnitudes):
38
+ output = dynamic_range_compression_torch(magnitudes)
39
+ return output
40
+
41
+
42
+ def spectral_de_normalize_torch(magnitudes):
43
+ output = dynamic_range_decompression_torch(magnitudes)
44
+ return output
45
+
46
+
47
+ mel_basis = {}
48
+ hann_window = {}
49
+
50
+
51
+ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
52
+ if torch.min(y) < -1.0:
53
+ print("min value is ", torch.min(y))
54
+ if torch.max(y) > 1.0:
55
+ print("max value is ", torch.max(y))
56
+
57
+ global hann_window
58
+ dtype_device = str(y.dtype) + "_" + str(y.device)
59
+ wnsize_dtype_device = str(win_size) + "_" + dtype_device
60
+ if wnsize_dtype_device not in hann_window:
61
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
62
+ dtype=y.dtype, device=y.device
63
+ )
64
+
65
+ y = torch.nn.functional.pad(
66
+ y.unsqueeze(1),
67
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
68
+ mode="reflect",
69
+ )
70
+ y = y.squeeze(1)
71
+ spec = torch.stft(
72
+ y,
73
+ n_fft,
74
+ hop_length=hop_size,
75
+ win_length=win_size,
76
+ window=hann_window[wnsize_dtype_device],
77
+ center=center,
78
+ pad_mode="reflect",
79
+ normalized=False,
80
+ onesided=True,
81
+ return_complex=False,
82
+ )
83
+
84
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
85
+ return spec
86
+
87
+
88
+ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
89
+ global mel_basis
90
+ dtype_device = str(spec.dtype) + "_" + str(spec.device)
91
+ fmax_dtype_device = str(fmax) + "_" + dtype_device
92
+ if fmax_dtype_device not in mel_basis:
93
+ mel = librosa_mel_fn(
94
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
95
+ )
96
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
97
+ dtype=spec.dtype, device=spec.device
98
+ )
99
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
100
+ spec = spectral_normalize_torch(spec)
101
+ return spec
102
+
103
+
104
+ def mel_spectrogram_torch(
105
+ y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
106
+ ):
107
+ if torch.min(y) < -1.0:
108
+ print("min value is ", torch.min(y))
109
+ if torch.max(y) > 1.0:
110
+ print("max value is ", torch.max(y))
111
+
112
+ global mel_basis, hann_window
113
+ dtype_device = str(y.dtype) + "_" + str(y.device)
114
+ fmax_dtype_device = str(fmax) + "_" + dtype_device
115
+ wnsize_dtype_device = str(win_size) + "_" + dtype_device
116
+ if fmax_dtype_device not in mel_basis:
117
+ mel = librosa_mel_fn(
118
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
119
+ )
120
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
121
+ dtype=y.dtype, device=y.device
122
+ )
123
+ if wnsize_dtype_device not in hann_window:
124
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
125
+ dtype=y.dtype, device=y.device
126
+ )
127
+
128
+ y = torch.nn.functional.pad(
129
+ y.unsqueeze(1),
130
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
131
+ mode="reflect",
132
+ )
133
+ y = y.squeeze(1)
134
+
135
+ spec = torch.stft(
136
+ y,
137
+ n_fft,
138
+ hop_length=hop_size,
139
+ win_length=win_size,
140
+ window=hann_window[wnsize_dtype_device],
141
+ center=center,
142
+ pad_mode="reflect",
143
+ normalized=False,
144
+ onesided=True,
145
+ return_complex=False,
146
+ )
147
+
148
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
149
+
150
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
151
+ spec = spectral_normalize_torch(spec)
152
+
153
+ return spec
module/models.py ADDED
@@ -0,0 +1,989 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from module import commons
8
+ from module import modules
9
+ from module import attentions
10
+
11
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
12
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13
+ from module.commons import init_weights, get_padding
14
+ from module.mrte_model import MRTE
15
+ from module.quantize import ResidualVectorQuantizer
16
+ from text import symbols
17
+ from torch.cuda.amp import autocast
18
+
19
+
20
+ class StochasticDurationPredictor(nn.Module):
21
+ def __init__(
22
+ self,
23
+ in_channels,
24
+ filter_channels,
25
+ kernel_size,
26
+ p_dropout,
27
+ n_flows=4,
28
+ gin_channels=0,
29
+ ):
30
+ super().__init__()
31
+ filter_channels = in_channels # it needs to be removed from future version.
32
+ self.in_channels = in_channels
33
+ self.filter_channels = filter_channels
34
+ self.kernel_size = kernel_size
35
+ self.p_dropout = p_dropout
36
+ self.n_flows = n_flows
37
+ self.gin_channels = gin_channels
38
+
39
+ self.log_flow = modules.Log()
40
+ self.flows = nn.ModuleList()
41
+ self.flows.append(modules.ElementwiseAffine(2))
42
+ for i in range(n_flows):
43
+ self.flows.append(
44
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
45
+ )
46
+ self.flows.append(modules.Flip())
47
+
48
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
49
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
50
+ self.post_convs = modules.DDSConv(
51
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
52
+ )
53
+ self.post_flows = nn.ModuleList()
54
+ self.post_flows.append(modules.ElementwiseAffine(2))
55
+ for i in range(4):
56
+ self.post_flows.append(
57
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
58
+ )
59
+ self.post_flows.append(modules.Flip())
60
+
61
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
62
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
63
+ self.convs = modules.DDSConv(
64
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
65
+ )
66
+ if gin_channels != 0:
67
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
68
+
69
+ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
70
+ x = torch.detach(x)
71
+ x = self.pre(x)
72
+ if g is not None:
73
+ g = torch.detach(g)
74
+ x = x + self.cond(g)
75
+ x = self.convs(x, x_mask)
76
+ x = self.proj(x) * x_mask
77
+
78
+ if not reverse:
79
+ flows = self.flows
80
+ assert w is not None
81
+
82
+ logdet_tot_q = 0
83
+ h_w = self.post_pre(w)
84
+ h_w = self.post_convs(h_w, x_mask)
85
+ h_w = self.post_proj(h_w) * x_mask
86
+ e_q = (
87
+ torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
88
+ * x_mask
89
+ )
90
+ z_q = e_q
91
+ for flow in self.post_flows:
92
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
93
+ logdet_tot_q += logdet_q
94
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
95
+ u = torch.sigmoid(z_u) * x_mask
96
+ z0 = (w - u) * x_mask
97
+ logdet_tot_q += torch.sum(
98
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
99
+ )
100
+ logq = (
101
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
102
+ - logdet_tot_q
103
+ )
104
+
105
+ logdet_tot = 0
106
+ z0, logdet = self.log_flow(z0, x_mask)
107
+ logdet_tot += logdet
108
+ z = torch.cat([z0, z1], 1)
109
+ for flow in flows:
110
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
111
+ logdet_tot = logdet_tot + logdet
112
+ nll = (
113
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
114
+ - logdet_tot
115
+ )
116
+ return nll + logq # [b]
117
+ else:
118
+ flows = list(reversed(self.flows))
119
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
120
+ z = (
121
+ torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
122
+ * noise_scale
123
+ )
124
+ for flow in flows:
125
+ z = flow(z, x_mask, g=x, reverse=reverse)
126
+ z0, z1 = torch.split(z, [1, 1], 1)
127
+ logw = z0
128
+ return logw
129
+
130
+
131
+ class DurationPredictor(nn.Module):
132
+ def __init__(
133
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
134
+ ):
135
+ super().__init__()
136
+
137
+ self.in_channels = in_channels
138
+ self.filter_channels = filter_channels
139
+ self.kernel_size = kernel_size
140
+ self.p_dropout = p_dropout
141
+ self.gin_channels = gin_channels
142
+
143
+ self.drop = nn.Dropout(p_dropout)
144
+ self.conv_1 = nn.Conv1d(
145
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
146
+ )
147
+ self.norm_1 = modules.LayerNorm(filter_channels)
148
+ self.conv_2 = nn.Conv1d(
149
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
150
+ )
151
+ self.norm_2 = modules.LayerNorm(filter_channels)
152
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
153
+
154
+ if gin_channels != 0:
155
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
156
+
157
+ def forward(self, x, x_mask, g=None):
158
+ x = torch.detach(x)
159
+ if g is not None:
160
+ g = torch.detach(g)
161
+ x = x + self.cond(g)
162
+ x = self.conv_1(x * x_mask)
163
+ x = torch.relu(x)
164
+ x = self.norm_1(x)
165
+ x = self.drop(x)
166
+ x = self.conv_2(x * x_mask)
167
+ x = torch.relu(x)
168
+ x = self.norm_2(x)
169
+ x = self.drop(x)
170
+ x = self.proj(x * x_mask)
171
+ return x * x_mask
172
+
173
+
174
+ class TextEncoder(nn.Module):
175
+ def __init__(
176
+ self,
177
+ out_channels,
178
+ hidden_channels,
179
+ filter_channels,
180
+ n_heads,
181
+ n_layers,
182
+ kernel_size,
183
+ p_dropout,
184
+ latent_channels=192,
185
+ ):
186
+ super().__init__()
187
+ self.out_channels = out_channels
188
+ self.hidden_channels = hidden_channels
189
+ self.filter_channels = filter_channels
190
+ self.n_heads = n_heads
191
+ self.n_layers = n_layers
192
+ self.kernel_size = kernel_size
193
+ self.p_dropout = p_dropout
194
+ self.latent_channels = latent_channels
195
+
196
+ self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
197
+
198
+ self.encoder_ssl = attentions.Encoder(
199
+ hidden_channels,
200
+ filter_channels,
201
+ n_heads,
202
+ n_layers // 2,
203
+ kernel_size,
204
+ p_dropout,
205
+ )
206
+
207
+ self.encoder_text = attentions.Encoder(
208
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
209
+ )
210
+ self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
211
+
212
+ self.mrte = MRTE()
213
+
214
+ self.encoder2 = attentions.Encoder(
215
+ hidden_channels,
216
+ filter_channels,
217
+ n_heads,
218
+ n_layers // 2,
219
+ kernel_size,
220
+ p_dropout,
221
+ )
222
+
223
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
224
+
225
+ def forward(self, y, y_lengths, text, text_lengths, ge, test=None):
226
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
227
+ y.dtype
228
+ )
229
+
230
+ y = self.ssl_proj(y * y_mask) * y_mask
231
+ y = self.encoder_ssl(y * y_mask, y_mask)
232
+
233
+ text_mask = torch.unsqueeze(
234
+ commons.sequence_mask(text_lengths, text.size(1)), 1
235
+ ).to(y.dtype)
236
+ if test == 1:
237
+ text[:, :] = 0
238
+ text = self.text_embedding(text).transpose(1, 2)
239
+ text = self.encoder_text(text * text_mask, text_mask)
240
+ y = self.mrte(y, y_mask, text, text_mask, ge)
241
+
242
+ y = self.encoder2(y * y_mask, y_mask)
243
+
244
+ stats = self.proj(y) * y_mask
245
+ m, logs = torch.split(stats, self.out_channels, dim=1)
246
+ return y, m, logs, y_mask
247
+
248
+ def extract_latent(self, x):
249
+ x = self.ssl_proj(x)
250
+ quantized, codes, commit_loss, quantized_list = self.quantizer(x)
251
+ return codes.transpose(0, 1)
252
+
253
+ def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
254
+ quantized = self.quantizer.decode(codes)
255
+
256
+ y = self.vq_proj(quantized) * y_mask
257
+ y = self.encoder_ssl(y * y_mask, y_mask)
258
+
259
+ y = self.mrte(y, y_mask, refer, refer_mask, ge)
260
+
261
+ y = self.encoder2(y * y_mask, y_mask)
262
+
263
+ stats = self.proj(y) * y_mask
264
+ m, logs = torch.split(stats, self.out_channels, dim=1)
265
+ return y, m, logs, y_mask, quantized
266
+
267
+
268
+ class ResidualCouplingBlock(nn.Module):
269
+ def __init__(
270
+ self,
271
+ channels,
272
+ hidden_channels,
273
+ kernel_size,
274
+ dilation_rate,
275
+ n_layers,
276
+ n_flows=4,
277
+ gin_channels=0,
278
+ ):
279
+ super().__init__()
280
+ self.channels = channels
281
+ self.hidden_channels = hidden_channels
282
+ self.kernel_size = kernel_size
283
+ self.dilation_rate = dilation_rate
284
+ self.n_layers = n_layers
285
+ self.n_flows = n_flows
286
+ self.gin_channels = gin_channels
287
+
288
+ self.flows = nn.ModuleList()
289
+ for i in range(n_flows):
290
+ self.flows.append(
291
+ modules.ResidualCouplingLayer(
292
+ channels,
293
+ hidden_channels,
294
+ kernel_size,
295
+ dilation_rate,
296
+ n_layers,
297
+ gin_channels=gin_channels,
298
+ mean_only=True,
299
+ )
300
+ )
301
+ self.flows.append(modules.Flip())
302
+
303
+ def forward(self, x, x_mask, g=None, reverse=False):
304
+ if not reverse:
305
+ for flow in self.flows:
306
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
307
+ else:
308
+ for flow in reversed(self.flows):
309
+ x = flow(x, x_mask, g=g, reverse=reverse)
310
+ return x
311
+
312
+
313
+ class PosteriorEncoder(nn.Module):
314
+ def __init__(
315
+ self,
316
+ in_channels,
317
+ out_channels,
318
+ hidden_channels,
319
+ kernel_size,
320
+ dilation_rate,
321
+ n_layers,
322
+ gin_channels=0,
323
+ ):
324
+ super().__init__()
325
+ self.in_channels = in_channels
326
+ self.out_channels = out_channels
327
+ self.hidden_channels = hidden_channels
328
+ self.kernel_size = kernel_size
329
+ self.dilation_rate = dilation_rate
330
+ self.n_layers = n_layers
331
+ self.gin_channels = gin_channels
332
+
333
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
334
+ self.enc = modules.WN(
335
+ hidden_channels,
336
+ kernel_size,
337
+ dilation_rate,
338
+ n_layers,
339
+ gin_channels=gin_channels,
340
+ )
341
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
342
+
343
+ def forward(self, x, x_lengths, g=None):
344
+ if g != None:
345
+ g = g.detach()
346
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
347
+ x.dtype
348
+ )
349
+ x = self.pre(x) * x_mask
350
+ x = self.enc(x, x_mask, g=g)
351
+ stats = self.proj(x) * x_mask
352
+ m, logs = torch.split(stats, self.out_channels, dim=1)
353
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
354
+ return z, m, logs, x_mask
355
+
356
+
357
+ class WNEncoder(nn.Module):
358
+ def __init__(
359
+ self,
360
+ in_channels,
361
+ out_channels,
362
+ hidden_channels,
363
+ kernel_size,
364
+ dilation_rate,
365
+ n_layers,
366
+ gin_channels=0,
367
+ ):
368
+ super().__init__()
369
+ self.in_channels = in_channels
370
+ self.out_channels = out_channels
371
+ self.hidden_channels = hidden_channels
372
+ self.kernel_size = kernel_size
373
+ self.dilation_rate = dilation_rate
374
+ self.n_layers = n_layers
375
+ self.gin_channels = gin_channels
376
+
377
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
378
+ self.enc = modules.WN(
379
+ hidden_channels,
380
+ kernel_size,
381
+ dilation_rate,
382
+ n_layers,
383
+ gin_channels=gin_channels,
384
+ )
385
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
386
+ self.norm = modules.LayerNorm(out_channels)
387
+
388
+ def forward(self, x, x_lengths, g=None):
389
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
390
+ x.dtype
391
+ )
392
+ x = self.pre(x) * x_mask
393
+ x = self.enc(x, x_mask, g=g)
394
+ out = self.proj(x) * x_mask
395
+ out = self.norm(out)
396
+ return out
397
+
398
+
399
+ class Generator(torch.nn.Module):
400
+ def __init__(
401
+ self,
402
+ initial_channel,
403
+ resblock,
404
+ resblock_kernel_sizes,
405
+ resblock_dilation_sizes,
406
+ upsample_rates,
407
+ upsample_initial_channel,
408
+ upsample_kernel_sizes,
409
+ gin_channels=0,
410
+ ):
411
+ super(Generator, self).__init__()
412
+ self.num_kernels = len(resblock_kernel_sizes)
413
+ self.num_upsamples = len(upsample_rates)
414
+ self.conv_pre = Conv1d(
415
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
416
+ )
417
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
418
+
419
+ self.ups = nn.ModuleList()
420
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
421
+ self.ups.append(
422
+ weight_norm(
423
+ ConvTranspose1d(
424
+ upsample_initial_channel // (2**i),
425
+ upsample_initial_channel // (2 ** (i + 1)),
426
+ k,
427
+ u,
428
+ padding=(k - u) // 2,
429
+ )
430
+ )
431
+ )
432
+
433
+ self.resblocks = nn.ModuleList()
434
+ for i in range(len(self.ups)):
435
+ ch = upsample_initial_channel // (2 ** (i + 1))
436
+ for j, (k, d) in enumerate(
437
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
438
+ ):
439
+ self.resblocks.append(resblock(ch, k, d))
440
+
441
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
442
+ self.ups.apply(init_weights)
443
+
444
+ if gin_channels != 0:
445
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
446
+
447
+ def forward(self, x, g=None):
448
+ x = self.conv_pre(x)
449
+ if g is not None:
450
+ x = x + self.cond(g)
451
+
452
+ for i in range(self.num_upsamples):
453
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
454
+ x = self.ups[i](x)
455
+ xs = None
456
+ for j in range(self.num_kernels):
457
+ if xs is None:
458
+ xs = self.resblocks[i * self.num_kernels + j](x)
459
+ else:
460
+ xs += self.resblocks[i * self.num_kernels + j](x)
461
+ x = xs / self.num_kernels
462
+ x = F.leaky_relu(x)
463
+ x = self.conv_post(x)
464
+ x = torch.tanh(x)
465
+
466
+ return x
467
+
468
+ def remove_weight_norm(self):
469
+ print("Removing weight norm...")
470
+ for l in self.ups:
471
+ remove_weight_norm(l)
472
+ for l in self.resblocks:
473
+ l.remove_weight_norm()
474
+
475
+
476
+ class DiscriminatorP(torch.nn.Module):
477
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
478
+ super(DiscriminatorP, self).__init__()
479
+ self.period = period
480
+ self.use_spectral_norm = use_spectral_norm
481
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
482
+ self.convs = nn.ModuleList(
483
+ [
484
+ norm_f(
485
+ Conv2d(
486
+ 1,
487
+ 32,
488
+ (kernel_size, 1),
489
+ (stride, 1),
490
+ padding=(get_padding(kernel_size, 1), 0),
491
+ )
492
+ ),
493
+ norm_f(
494
+ Conv2d(
495
+ 32,
496
+ 128,
497
+ (kernel_size, 1),
498
+ (stride, 1),
499
+ padding=(get_padding(kernel_size, 1), 0),
500
+ )
501
+ ),
502
+ norm_f(
503
+ Conv2d(
504
+ 128,
505
+ 512,
506
+ (kernel_size, 1),
507
+ (stride, 1),
508
+ padding=(get_padding(kernel_size, 1), 0),
509
+ )
510
+ ),
511
+ norm_f(
512
+ Conv2d(
513
+ 512,
514
+ 1024,
515
+ (kernel_size, 1),
516
+ (stride, 1),
517
+ padding=(get_padding(kernel_size, 1), 0),
518
+ )
519
+ ),
520
+ norm_f(
521
+ Conv2d(
522
+ 1024,
523
+ 1024,
524
+ (kernel_size, 1),
525
+ 1,
526
+ padding=(get_padding(kernel_size, 1), 0),
527
+ )
528
+ ),
529
+ ]
530
+ )
531
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
532
+
533
+ def forward(self, x):
534
+ fmap = []
535
+
536
+ # 1d to 2d
537
+ b, c, t = x.shape
538
+ if t % self.period != 0: # pad first
539
+ n_pad = self.period - (t % self.period)
540
+ x = F.pad(x, (0, n_pad), "reflect")
541
+ t = t + n_pad
542
+ x = x.view(b, c, t // self.period, self.period)
543
+
544
+ for l in self.convs:
545
+ x = l(x)
546
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
547
+ fmap.append(x)
548
+ x = self.conv_post(x)
549
+ fmap.append(x)
550
+ x = torch.flatten(x, 1, -1)
551
+
552
+ return x, fmap
553
+
554
+
555
+ class DiscriminatorS(torch.nn.Module):
556
+ def __init__(self, use_spectral_norm=False):
557
+ super(DiscriminatorS, self).__init__()
558
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
559
+ self.convs = nn.ModuleList(
560
+ [
561
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
562
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
563
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
564
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
565
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
566
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
567
+ ]
568
+ )
569
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
570
+
571
+ def forward(self, x):
572
+ fmap = []
573
+
574
+ for l in self.convs:
575
+ x = l(x)
576
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
577
+ fmap.append(x)
578
+ x = self.conv_post(x)
579
+ fmap.append(x)
580
+ x = torch.flatten(x, 1, -1)
581
+
582
+ return x, fmap
583
+
584
+
585
+ class MultiPeriodDiscriminator(torch.nn.Module):
586
+ def __init__(self, use_spectral_norm=False):
587
+ super(MultiPeriodDiscriminator, self).__init__()
588
+ periods = [2, 3, 5, 7, 11]
589
+
590
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
591
+ discs = discs + [
592
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
593
+ ]
594
+ self.discriminators = nn.ModuleList(discs)
595
+
596
+ def forward(self, y, y_hat):
597
+ y_d_rs = []
598
+ y_d_gs = []
599
+ fmap_rs = []
600
+ fmap_gs = []
601
+ for i, d in enumerate(self.discriminators):
602
+ y_d_r, fmap_r = d(y)
603
+ y_d_g, fmap_g = d(y_hat)
604
+ y_d_rs.append(y_d_r)
605
+ y_d_gs.append(y_d_g)
606
+ fmap_rs.append(fmap_r)
607
+ fmap_gs.append(fmap_g)
608
+
609
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
610
+
611
+
612
+ class ReferenceEncoder(nn.Module):
613
+ """
614
+ inputs --- [N, Ty/r, n_mels*r] mels
615
+ outputs --- [N, ref_enc_gru_size]
616
+ """
617
+
618
+ def __init__(self, spec_channels, gin_channels=0):
619
+ super().__init__()
620
+ self.spec_channels = spec_channels
621
+ ref_enc_filters = [32, 32, 64, 64, 128, 128]
622
+ K = len(ref_enc_filters)
623
+ filters = [1] + ref_enc_filters
624
+ convs = [
625
+ weight_norm(
626
+ nn.Conv2d(
627
+ in_channels=filters[i],
628
+ out_channels=filters[i + 1],
629
+ kernel_size=(3, 3),
630
+ stride=(2, 2),
631
+ padding=(1, 1),
632
+ )
633
+ )
634
+ for i in range(K)
635
+ ]
636
+ self.convs = nn.ModuleList(convs)
637
+ # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])
638
+
639
+ out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
640
+ self.gru = nn.GRU(
641
+ input_size=ref_enc_filters[-1] * out_channels,
642
+ hidden_size=256 // 2,
643
+ batch_first=True,
644
+ )
645
+ self.proj = nn.Linear(128, gin_channels)
646
+
647
+ def forward(self, inputs):
648
+ N = inputs.size(0)
649
+ out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
650
+ for conv in self.convs:
651
+ out = conv(out)
652
+ # out = wn(out)
653
+ out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
654
+
655
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
656
+ T = out.size(1)
657
+ N = out.size(0)
658
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
659
+
660
+ self.gru.flatten_parameters()
661
+ memory, out = self.gru(out) # out --- [1, N, 128]
662
+
663
+ return self.proj(out.squeeze(0)).unsqueeze(-1)
664
+
665
+ def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
666
+ for i in range(n_convs):
667
+ L = (L - kernel_size + 2 * pad) // stride + 1
668
+ return L
669
+
670
+
671
+ class Quantizer_module(torch.nn.Module):
672
+ def __init__(self, n_e, e_dim):
673
+ super(Quantizer_module, self).__init__()
674
+ self.embedding = nn.Embedding(n_e, e_dim)
675
+ self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)
676
+
677
+ def forward(self, x):
678
+ d = (
679
+ torch.sum(x**2, 1, keepdim=True)
680
+ + torch.sum(self.embedding.weight**2, 1)
681
+ - 2 * torch.matmul(x, self.embedding.weight.T)
682
+ )
683
+ min_indicies = torch.argmin(d, 1)
684
+ z_q = self.embedding(min_indicies)
685
+ return z_q, min_indicies
686
+
687
+
688
+ class Quantizer(torch.nn.Module):
689
+ def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160):
690
+ super(Quantizer, self).__init__()
691
+ assert embed_dim % n_code_groups == 0
692
+ self.quantizer_modules = nn.ModuleList(
693
+ [
694
+ Quantizer_module(n_codes, embed_dim // n_code_groups)
695
+ for _ in range(n_code_groups)
696
+ ]
697
+ )
698
+ self.n_code_groups = n_code_groups
699
+ self.embed_dim = embed_dim
700
+
701
+ def forward(self, xin):
702
+ # B, C, T
703
+ B, C, T = xin.shape
704
+ xin = xin.transpose(1, 2)
705
+ x = xin.reshape(-1, self.embed_dim)
706
+ x = torch.split(x, self.embed_dim // self.n_code_groups, dim=-1)
707
+ min_indicies = []
708
+ z_q = []
709
+ for _x, m in zip(x, self.quantizer_modules):
710
+ _z_q, _min_indicies = m(_x)
711
+ z_q.append(_z_q)
712
+ min_indicies.append(_min_indicies) # B * T,
713
+ z_q = torch.cat(z_q, -1).reshape(xin.shape)
714
+ loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
715
+ (z_q - xin.detach()) ** 2
716
+ )
717
+ z_q = xin + (z_q - xin).detach()
718
+ z_q = z_q.transpose(1, 2)
719
+ codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
720
+ return z_q, loss, codes.transpose(1, 2)
721
+
722
+ def embed(self, x):
723
+ # idx: N, 4, T
724
+ x = x.transpose(1, 2)
725
+ x = torch.split(x, 1, 2)
726
+ ret = []
727
+ for q, embed in zip(x, self.quantizer_modules):
728
+ q = embed.embedding(q.squeeze(-1))
729
+ ret.append(q)
730
+ ret = torch.cat(ret, -1)
731
+ return ret.transpose(1, 2) # N, C, T
732
+
733
+
734
+ class CodePredictor(nn.Module):
735
+ def __init__(
736
+ self,
737
+ hidden_channels,
738
+ filter_channels,
739
+ n_heads,
740
+ n_layers,
741
+ kernel_size,
742
+ p_dropout,
743
+ n_q=8,
744
+ dims=1024,
745
+ ssl_dim=768,
746
+ ):
747
+ super().__init__()
748
+ self.hidden_channels = hidden_channels
749
+ self.filter_channels = filter_channels
750
+ self.n_heads = n_heads
751
+ self.n_layers = n_layers
752
+ self.kernel_size = kernel_size
753
+ self.p_dropout = p_dropout
754
+
755
+ self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
756
+ self.ref_enc = modules.MelStyleEncoder(
757
+ ssl_dim, style_vector_dim=hidden_channels
758
+ )
759
+
760
+ self.encoder = attentions.Encoder(
761
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
762
+ )
763
+
764
+ self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
765
+ self.n_q = n_q
766
+ self.dims = dims
767
+
768
+ def forward(self, x, x_mask, refer, codes, infer=False):
769
+ x = x.detach()
770
+ x = self.vq_proj(x * x_mask) * x_mask
771
+ g = self.ref_enc(refer, x_mask)
772
+ x = x + g
773
+ x = self.encoder(x * x_mask, x_mask)
774
+ x = self.out_proj(x * x_mask) * x_mask
775
+ logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
776
+ 2, 3
777
+ )
778
+ target = codes[1:].transpose(0, 1)
779
+ if not infer:
780
+ logits = logits.reshape(-1, self.dims)
781
+ target = target.reshape(-1)
782
+ loss = torch.nn.functional.cross_entropy(logits, target)
783
+ return loss
784
+ else:
785
+ _, top10_preds = torch.topk(logits, 10, dim=-1)
786
+ correct_top10 = torch.any(top10_preds == target.unsqueeze(-1), dim=-1)
787
+ top3_acc = 100 * torch.mean(correct_top10.float()).detach().cpu().item()
788
+
789
+ print("Top-10 Accuracy:", top3_acc, "%")
790
+
791
+ pred_codes = torch.argmax(logits, dim=-1)
792
+ acc = 100 * torch.mean((pred_codes == target).float()).detach().cpu().item()
793
+ print("Top-1 Accuracy:", acc, "%")
794
+
795
+ return pred_codes.transpose(0, 1)
796
+
797
+
798
+ class SynthesizerTrn(nn.Module):
799
+ """
800
+ Synthesizer for Training
801
+ """
802
+
803
+ def __init__(
804
+ self,
805
+ spec_channels,
806
+ segment_size,
807
+ inter_channels,
808
+ hidden_channels,
809
+ filter_channels,
810
+ n_heads,
811
+ n_layers,
812
+ kernel_size,
813
+ p_dropout,
814
+ resblock,
815
+ resblock_kernel_sizes,
816
+ resblock_dilation_sizes,
817
+ upsample_rates,
818
+ upsample_initial_channel,
819
+ upsample_kernel_sizes,
820
+ n_speakers=0,
821
+ gin_channels=0,
822
+ use_sdp=True,
823
+ semantic_frame_rate=None,
824
+ freeze_quantizer=None,
825
+ **kwargs
826
+ ):
827
+ super().__init__()
828
+ self.spec_channels = spec_channels
829
+ self.inter_channels = inter_channels
830
+ self.hidden_channels = hidden_channels
831
+ self.filter_channels = filter_channels
832
+ self.n_heads = n_heads
833
+ self.n_layers = n_layers
834
+ self.kernel_size = kernel_size
835
+ self.p_dropout = p_dropout
836
+ self.resblock = resblock
837
+ self.resblock_kernel_sizes = resblock_kernel_sizes
838
+ self.resblock_dilation_sizes = resblock_dilation_sizes
839
+ self.upsample_rates = upsample_rates
840
+ self.upsample_initial_channel = upsample_initial_channel
841
+ self.upsample_kernel_sizes = upsample_kernel_sizes
842
+ self.segment_size = segment_size
843
+ self.n_speakers = n_speakers
844
+ self.gin_channels = gin_channels
845
+
846
+ self.use_sdp = use_sdp
847
+ self.enc_p = TextEncoder(
848
+ inter_channels,
849
+ hidden_channels,
850
+ filter_channels,
851
+ n_heads,
852
+ n_layers,
853
+ kernel_size,
854
+ p_dropout,
855
+ )
856
+ self.dec = Generator(
857
+ inter_channels,
858
+ resblock,
859
+ resblock_kernel_sizes,
860
+ resblock_dilation_sizes,
861
+ upsample_rates,
862
+ upsample_initial_channel,
863
+ upsample_kernel_sizes,
864
+ gin_channels=gin_channels,
865
+ )
866
+ self.enc_q = PosteriorEncoder(
867
+ spec_channels,
868
+ inter_channels,
869
+ hidden_channels,
870
+ 5,
871
+ 1,
872
+ 16,
873
+ gin_channels=gin_channels,
874
+ )
875
+ self.flow = ResidualCouplingBlock(
876
+ inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
877
+ )
878
+
879
+ self.ref_enc = modules.MelStyleEncoder(
880
+ spec_channels, style_vector_dim=gin_channels
881
+ )
882
+
883
+ ssl_dim = 768
884
+ assert semantic_frame_rate in ["25hz", "50hz"]
885
+ self.semantic_frame_rate = semantic_frame_rate
886
+ if semantic_frame_rate == "25hz":
887
+ self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
888
+ else:
889
+ self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
890
+
891
+ self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
892
+ if freeze_quantizer:
893
+ self.ssl_proj.requires_grad_(False)
894
+ self.quantizer.requires_grad_(False)
895
+ # self.enc_p.text_embedding.requires_grad_(False)
896
+ # self.enc_p.encoder_text.requires_grad_(False)
897
+ # self.enc_p.mrte.requires_grad_(False)
898
+
899
+ def forward(self, ssl, y, y_lengths, text, text_lengths):
900
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
901
+ y.dtype
902
+ )
903
+ ge = self.ref_enc(y * y_mask, y_mask)
904
+
905
+ with autocast(enabled=False):
906
+ ssl = self.ssl_proj(ssl)
907
+ quantized, codes, commit_loss, quantized_list = self.quantizer(
908
+ ssl, layers=[0]
909
+ )
910
+
911
+ if self.semantic_frame_rate == "25hz":
912
+ quantized = F.interpolate(
913
+ quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
914
+ )
915
+
916
+ x, m_p, logs_p, y_mask = self.enc_p(
917
+ quantized, y_lengths, text, text_lengths, ge
918
+ )
919
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
920
+ z_p = self.flow(z, y_mask, g=ge)
921
+
922
+ z_slice, ids_slice = commons.rand_slice_segments(
923
+ z, y_lengths, self.segment_size
924
+ )
925
+ o = self.dec(z_slice, g=ge)
926
+ return (
927
+ o,
928
+ commit_loss,
929
+ ids_slice,
930
+ y_mask,
931
+ y_mask,
932
+ (z, z_p, m_p, logs_p, m_q, logs_q),
933
+ quantized,
934
+ )
935
+
936
+ def infer(self, ssl, y, y_lengths, text, text_lengths, test=None, noise_scale=0.5):
937
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
938
+ y.dtype
939
+ )
940
+ ge = self.ref_enc(y * y_mask, y_mask)
941
+
942
+ ssl = self.ssl_proj(ssl)
943
+ quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
944
+ if self.semantic_frame_rate == "25hz":
945
+ quantized = F.interpolate(
946
+ quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
947
+ )
948
+
949
+ x, m_p, logs_p, y_mask = self.enc_p(
950
+ quantized, y_lengths, text, text_lengths, ge, test=test
951
+ )
952
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
953
+
954
+ z = self.flow(z_p, y_mask, g=ge, reverse=True)
955
+
956
+ o = self.dec((z * y_mask)[:, :, :], g=ge)
957
+ return o, y_mask, (z, z_p, m_p, logs_p)
958
+
959
+ @torch.no_grad()
960
+ def decode(self, codes, text, refer, noise_scale=0.5):
961
+ refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
962
+ refer_mask = torch.unsqueeze(
963
+ commons.sequence_mask(refer_lengths, refer.size(2)), 1
964
+ ).to(refer.dtype)
965
+ ge = self.ref_enc(refer * refer_mask, refer_mask)
966
+
967
+ y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
968
+ text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
969
+
970
+ quantized = self.quantizer.decode(codes)
971
+ if self.semantic_frame_rate == "25hz":
972
+ quantized = F.interpolate(
973
+ quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
974
+ )
975
+
976
+ x, m_p, logs_p, y_mask = self.enc_p(
977
+ quantized, y_lengths, text, text_lengths, ge
978
+ )
979
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
980
+
981
+ z = self.flow(z_p, y_mask, g=ge, reverse=True)
982
+
983
+ o = self.dec((z * y_mask)[:, :, :], g=ge)
984
+ return o
985
+
986
+ def extract_latent(self, x):
987
+ ssl = self.ssl_proj(x)
988
+ quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
989
+ return codes.transpose(0, 1)
module/modules.py ADDED
@@ -0,0 +1,923 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from torch.nn import Conv1d
8
+ from torch.nn.utils import weight_norm, remove_weight_norm
9
+
10
+ from module import commons
11
+ from module.commons import init_weights, get_padding
12
+ from module.transforms import piecewise_rational_quadratic_transform
13
+ import torch.distributions as D
14
+
15
+
16
+ LRELU_SLOPE = 0.1
17
+
18
+
19
+ class LayerNorm(nn.Module):
20
+ def __init__(self, channels, eps=1e-5):
21
+ super().__init__()
22
+ self.channels = channels
23
+ self.eps = eps
24
+
25
+ self.gamma = nn.Parameter(torch.ones(channels))
26
+ self.beta = nn.Parameter(torch.zeros(channels))
27
+
28
+ def forward(self, x):
29
+ x = x.transpose(1, -1)
30
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
31
+ return x.transpose(1, -1)
32
+
33
+
34
+ class ConvReluNorm(nn.Module):
35
+ def __init__(
36
+ self,
37
+ in_channels,
38
+ hidden_channels,
39
+ out_channels,
40
+ kernel_size,
41
+ n_layers,
42
+ p_dropout,
43
+ ):
44
+ super().__init__()
45
+ self.in_channels = in_channels
46
+ self.hidden_channels = hidden_channels
47
+ self.out_channels = out_channels
48
+ self.kernel_size = kernel_size
49
+ self.n_layers = n_layers
50
+ self.p_dropout = p_dropout
51
+ assert n_layers > 1, "Number of layers should be larger than 0."
52
+
53
+ self.conv_layers = nn.ModuleList()
54
+ self.norm_layers = nn.ModuleList()
55
+ self.conv_layers.append(
56
+ nn.Conv1d(
57
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
58
+ )
59
+ )
60
+ self.norm_layers.append(LayerNorm(hidden_channels))
61
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
62
+ for _ in range(n_layers - 1):
63
+ self.conv_layers.append(
64
+ nn.Conv1d(
65
+ hidden_channels,
66
+ hidden_channels,
67
+ kernel_size,
68
+ padding=kernel_size // 2,
69
+ )
70
+ )
71
+ self.norm_layers.append(LayerNorm(hidden_channels))
72
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
73
+ self.proj.weight.data.zero_()
74
+ self.proj.bias.data.zero_()
75
+
76
+ def forward(self, x, x_mask):
77
+ x_org = x
78
+ for i in range(self.n_layers):
79
+ x = self.conv_layers[i](x * x_mask)
80
+ x = self.norm_layers[i](x)
81
+ x = self.relu_drop(x)
82
+ x = x_org + self.proj(x)
83
+ return x * x_mask
84
+
85
+
86
+ class DDSConv(nn.Module):
87
+ """
88
+ Dialted and Depth-Separable Convolution
89
+ """
90
+
91
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
92
+ super().__init__()
93
+ self.channels = channels
94
+ self.kernel_size = kernel_size
95
+ self.n_layers = n_layers
96
+ self.p_dropout = p_dropout
97
+
98
+ self.drop = nn.Dropout(p_dropout)
99
+ self.convs_sep = nn.ModuleList()
100
+ self.convs_1x1 = nn.ModuleList()
101
+ self.norms_1 = nn.ModuleList()
102
+ self.norms_2 = nn.ModuleList()
103
+ for i in range(n_layers):
104
+ dilation = kernel_size**i
105
+ padding = (kernel_size * dilation - dilation) // 2
106
+ self.convs_sep.append(
107
+ nn.Conv1d(
108
+ channels,
109
+ channels,
110
+ kernel_size,
111
+ groups=channels,
112
+ dilation=dilation,
113
+ padding=padding,
114
+ )
115
+ )
116
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
117
+ self.norms_1.append(LayerNorm(channels))
118
+ self.norms_2.append(LayerNorm(channels))
119
+
120
+ def forward(self, x, x_mask, g=None):
121
+ if g is not None:
122
+ x = x + g
123
+ for i in range(self.n_layers):
124
+ y = self.convs_sep[i](x * x_mask)
125
+ y = self.norms_1[i](y)
126
+ y = F.gelu(y)
127
+ y = self.convs_1x1[i](y)
128
+ y = self.norms_2[i](y)
129
+ y = F.gelu(y)
130
+ y = self.drop(y)
131
+ x = x + y
132
+ return x * x_mask
133
+
134
+
135
+ class WN(torch.nn.Module):
136
+ def __init__(
137
+ self,
138
+ hidden_channels,
139
+ kernel_size,
140
+ dilation_rate,
141
+ n_layers,
142
+ gin_channels=0,
143
+ p_dropout=0,
144
+ ):
145
+ super(WN, self).__init__()
146
+ assert kernel_size % 2 == 1
147
+ self.hidden_channels = hidden_channels
148
+ self.kernel_size = (kernel_size,)
149
+ self.dilation_rate = dilation_rate
150
+ self.n_layers = n_layers
151
+ self.gin_channels = gin_channels
152
+ self.p_dropout = p_dropout
153
+
154
+ self.in_layers = torch.nn.ModuleList()
155
+ self.res_skip_layers = torch.nn.ModuleList()
156
+ self.drop = nn.Dropout(p_dropout)
157
+
158
+ if gin_channels != 0:
159
+ cond_layer = torch.nn.Conv1d(
160
+ gin_channels, 2 * hidden_channels * n_layers, 1
161
+ )
162
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
163
+
164
+ for i in range(n_layers):
165
+ dilation = dilation_rate**i
166
+ padding = int((kernel_size * dilation - dilation) / 2)
167
+ in_layer = torch.nn.Conv1d(
168
+ hidden_channels,
169
+ 2 * hidden_channels,
170
+ kernel_size,
171
+ dilation=dilation,
172
+ padding=padding,
173
+ )
174
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
175
+ self.in_layers.append(in_layer)
176
+
177
+ # last one is not necessary
178
+ if i < n_layers - 1:
179
+ res_skip_channels = 2 * hidden_channels
180
+ else:
181
+ res_skip_channels = hidden_channels
182
+
183
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
184
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
185
+ self.res_skip_layers.append(res_skip_layer)
186
+
187
+ def forward(self, x, x_mask, g=None, **kwargs):
188
+ output = torch.zeros_like(x)
189
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
190
+
191
+ if g is not None:
192
+ g = self.cond_layer(g)
193
+
194
+ for i in range(self.n_layers):
195
+ x_in = self.in_layers[i](x)
196
+ if g is not None:
197
+ cond_offset = i * 2 * self.hidden_channels
198
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
199
+ else:
200
+ g_l = torch.zeros_like(x_in)
201
+
202
+ acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
203
+ acts = self.drop(acts)
204
+
205
+ res_skip_acts = self.res_skip_layers[i](acts)
206
+ if i < self.n_layers - 1:
207
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
208
+ x = (x + res_acts) * x_mask
209
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
210
+ else:
211
+ output = output + res_skip_acts
212
+ return output * x_mask
213
+
214
+ def remove_weight_norm(self):
215
+ if self.gin_channels != 0:
216
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
217
+ for l in self.in_layers:
218
+ torch.nn.utils.remove_weight_norm(l)
219
+ for l in self.res_skip_layers:
220
+ torch.nn.utils.remove_weight_norm(l)
221
+
222
+
223
+ class ResBlock1(torch.nn.Module):
224
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
225
+ super(ResBlock1, self).__init__()
226
+ self.convs1 = nn.ModuleList(
227
+ [
228
+ weight_norm(
229
+ Conv1d(
230
+ channels,
231
+ channels,
232
+ kernel_size,
233
+ 1,
234
+ dilation=dilation[0],
235
+ padding=get_padding(kernel_size, dilation[0]),
236
+ )
237
+ ),
238
+ weight_norm(
239
+ Conv1d(
240
+ channels,
241
+ channels,
242
+ kernel_size,
243
+ 1,
244
+ dilation=dilation[1],
245
+ padding=get_padding(kernel_size, dilation[1]),
246
+ )
247
+ ),
248
+ weight_norm(
249
+ Conv1d(
250
+ channels,
251
+ channels,
252
+ kernel_size,
253
+ 1,
254
+ dilation=dilation[2],
255
+ padding=get_padding(kernel_size, dilation[2]),
256
+ )
257
+ ),
258
+ ]
259
+ )
260
+ self.convs1.apply(init_weights)
261
+
262
+ self.convs2 = nn.ModuleList(
263
+ [
264
+ weight_norm(
265
+ Conv1d(
266
+ channels,
267
+ channels,
268
+ kernel_size,
269
+ 1,
270
+ dilation=1,
271
+ padding=get_padding(kernel_size, 1),
272
+ )
273
+ ),
274
+ weight_norm(
275
+ Conv1d(
276
+ channels,
277
+ channels,
278
+ kernel_size,
279
+ 1,
280
+ dilation=1,
281
+ padding=get_padding(kernel_size, 1),
282
+ )
283
+ ),
284
+ weight_norm(
285
+ Conv1d(
286
+ channels,
287
+ channels,
288
+ kernel_size,
289
+ 1,
290
+ dilation=1,
291
+ padding=get_padding(kernel_size, 1),
292
+ )
293
+ ),
294
+ ]
295
+ )
296
+ self.convs2.apply(init_weights)
297
+
298
+ def forward(self, x, x_mask=None):
299
+ for c1, c2 in zip(self.convs1, self.convs2):
300
+ xt = F.leaky_relu(x, LRELU_SLOPE)
301
+ if x_mask is not None:
302
+ xt = xt * x_mask
303
+ xt = c1(xt)
304
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
305
+ if x_mask is not None:
306
+ xt = xt * x_mask
307
+ xt = c2(xt)
308
+ x = xt + x
309
+ if x_mask is not None:
310
+ x = x * x_mask
311
+ return x
312
+
313
+ def remove_weight_norm(self):
314
+ for l in self.convs1:
315
+ remove_weight_norm(l)
316
+ for l in self.convs2:
317
+ remove_weight_norm(l)
318
+
319
+
320
+ class ResBlock2(torch.nn.Module):
321
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
322
+ super(ResBlock2, self).__init__()
323
+ self.convs = nn.ModuleList(
324
+ [
325
+ weight_norm(
326
+ Conv1d(
327
+ channels,
328
+ channels,
329
+ kernel_size,
330
+ 1,
331
+ dilation=dilation[0],
332
+ padding=get_padding(kernel_size, dilation[0]),
333
+ )
334
+ ),
335
+ weight_norm(
336
+ Conv1d(
337
+ channels,
338
+ channels,
339
+ kernel_size,
340
+ 1,
341
+ dilation=dilation[1],
342
+ padding=get_padding(kernel_size, dilation[1]),
343
+ )
344
+ ),
345
+ ]
346
+ )
347
+ self.convs.apply(init_weights)
348
+
349
+ def forward(self, x, x_mask=None):
350
+ for c in self.convs:
351
+ xt = F.leaky_relu(x, LRELU_SLOPE)
352
+ if x_mask is not None:
353
+ xt = xt * x_mask
354
+ xt = c(xt)
355
+ x = xt + x
356
+ if x_mask is not None:
357
+ x = x * x_mask
358
+ return x
359
+
360
+ def remove_weight_norm(self):
361
+ for l in self.convs:
362
+ remove_weight_norm(l)
363
+
364
+
365
+ class Log(nn.Module):
366
+ def forward(self, x, x_mask, reverse=False, **kwargs):
367
+ if not reverse:
368
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
369
+ logdet = torch.sum(-y, [1, 2])
370
+ return y, logdet
371
+ else:
372
+ x = torch.exp(x) * x_mask
373
+ return x
374
+
375
+
376
+ class Flip(nn.Module):
377
+ def forward(self, x, *args, reverse=False, **kwargs):
378
+ x = torch.flip(x, [1])
379
+ if not reverse:
380
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
381
+ return x, logdet
382
+ else:
383
+ return x
384
+
385
+
386
+ class ElementwiseAffine(nn.Module):
387
+ def __init__(self, channels):
388
+ super().__init__()
389
+ self.channels = channels
390
+ self.m = nn.Parameter(torch.zeros(channels, 1))
391
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
392
+
393
+ def forward(self, x, x_mask, reverse=False, **kwargs):
394
+ if not reverse:
395
+ y = self.m + torch.exp(self.logs) * x
396
+ y = y * x_mask
397
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
398
+ return y, logdet
399
+ else:
400
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
401
+ return x
402
+
403
+
404
+ class ResidualCouplingLayer(nn.Module):
405
+ def __init__(
406
+ self,
407
+ channels,
408
+ hidden_channels,
409
+ kernel_size,
410
+ dilation_rate,
411
+ n_layers,
412
+ p_dropout=0,
413
+ gin_channels=0,
414
+ mean_only=False,
415
+ ):
416
+ assert channels % 2 == 0, "channels should be divisible by 2"
417
+ super().__init__()
418
+ self.channels = channels
419
+ self.hidden_channels = hidden_channels
420
+ self.kernel_size = kernel_size
421
+ self.dilation_rate = dilation_rate
422
+ self.n_layers = n_layers
423
+ self.half_channels = channels // 2
424
+ self.mean_only = mean_only
425
+
426
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
427
+ self.enc = WN(
428
+ hidden_channels,
429
+ kernel_size,
430
+ dilation_rate,
431
+ n_layers,
432
+ p_dropout=p_dropout,
433
+ gin_channels=gin_channels,
434
+ )
435
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
436
+ self.post.weight.data.zero_()
437
+ self.post.bias.data.zero_()
438
+
439
+ def forward(self, x, x_mask, g=None, reverse=False):
440
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
441
+ h = self.pre(x0) * x_mask
442
+ h = self.enc(h, x_mask, g=g)
443
+ stats = self.post(h) * x_mask
444
+ if not self.mean_only:
445
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
446
+ else:
447
+ m = stats
448
+ logs = torch.zeros_like(m)
449
+
450
+ if not reverse:
451
+ x1 = m + x1 * torch.exp(logs) * x_mask
452
+ x = torch.cat([x0, x1], 1)
453
+ logdet = torch.sum(logs, [1, 2])
454
+ return x, logdet
455
+ else:
456
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
457
+ x = torch.cat([x0, x1], 1)
458
+ return x
459
+
460
+
461
+ class ConvFlow(nn.Module):
462
+ def __init__(
463
+ self,
464
+ in_channels,
465
+ filter_channels,
466
+ kernel_size,
467
+ n_layers,
468
+ num_bins=10,
469
+ tail_bound=5.0,
470
+ ):
471
+ super().__init__()
472
+ self.in_channels = in_channels
473
+ self.filter_channels = filter_channels
474
+ self.kernel_size = kernel_size
475
+ self.n_layers = n_layers
476
+ self.num_bins = num_bins
477
+ self.tail_bound = tail_bound
478
+ self.half_channels = in_channels // 2
479
+
480
+ self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
481
+ self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
482
+ self.proj = nn.Conv1d(
483
+ filter_channels, self.half_channels * (num_bins * 3 - 1), 1
484
+ )
485
+ self.proj.weight.data.zero_()
486
+ self.proj.bias.data.zero_()
487
+
488
+ def forward(self, x, x_mask, g=None, reverse=False):
489
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
490
+ h = self.pre(x0)
491
+ h = self.convs(h, x_mask, g=g)
492
+ h = self.proj(h) * x_mask
493
+
494
+ b, c, t = x0.shape
495
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
496
+
497
+ unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
498
+ unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
499
+ self.filter_channels
500
+ )
501
+ unnormalized_derivatives = h[..., 2 * self.num_bins :]
502
+
503
+ x1, logabsdet = piecewise_rational_quadratic_transform(
504
+ x1,
505
+ unnormalized_widths,
506
+ unnormalized_heights,
507
+ unnormalized_derivatives,
508
+ inverse=reverse,
509
+ tails="linear",
510
+ tail_bound=self.tail_bound,
511
+ )
512
+
513
+ x = torch.cat([x0, x1], 1) * x_mask
514
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
515
+ if not reverse:
516
+ return x, logdet
517
+ else:
518
+ return x
519
+
520
+
521
+ class LinearNorm(nn.Module):
522
+ def __init__(
523
+ self,
524
+ in_channels,
525
+ out_channels,
526
+ bias=True,
527
+ spectral_norm=False,
528
+ ):
529
+ super(LinearNorm, self).__init__()
530
+ self.fc = nn.Linear(in_channels, out_channels, bias)
531
+
532
+ if spectral_norm:
533
+ self.fc = nn.utils.spectral_norm(self.fc)
534
+
535
+ def forward(self, input):
536
+ out = self.fc(input)
537
+ return out
538
+
539
+
540
+ class Mish(nn.Module):
541
+ def __init__(self):
542
+ super(Mish, self).__init__()
543
+
544
+ def forward(self, x):
545
+ return x * torch.tanh(F.softplus(x))
546
+
547
+
548
+ class Conv1dGLU(nn.Module):
549
+ """
550
+ Conv1d + GLU(Gated Linear Unit) with residual connection.
551
+ For GLU refer to https://arxiv.org/abs/1612.08083 paper.
552
+ """
553
+
554
+ def __init__(self, in_channels, out_channels, kernel_size, dropout):
555
+ super(Conv1dGLU, self).__init__()
556
+ self.out_channels = out_channels
557
+ self.conv1 = ConvNorm(in_channels, 2 * out_channels, kernel_size=kernel_size)
558
+ self.dropout = nn.Dropout(dropout)
559
+
560
+ def forward(self, x):
561
+ residual = x
562
+ x = self.conv1(x)
563
+ x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
564
+ x = x1 * torch.sigmoid(x2)
565
+ x = residual + self.dropout(x)
566
+ return x
567
+
568
+
569
+ class ConvNorm(nn.Module):
570
+ def __init__(
571
+ self,
572
+ in_channels,
573
+ out_channels,
574
+ kernel_size=1,
575
+ stride=1,
576
+ padding=None,
577
+ dilation=1,
578
+ bias=True,
579
+ spectral_norm=False,
580
+ ):
581
+ super(ConvNorm, self).__init__()
582
+
583
+ if padding is None:
584
+ assert kernel_size % 2 == 1
585
+ padding = int(dilation * (kernel_size - 1) / 2)
586
+
587
+ self.conv = torch.nn.Conv1d(
588
+ in_channels,
589
+ out_channels,
590
+ kernel_size=kernel_size,
591
+ stride=stride,
592
+ padding=padding,
593
+ dilation=dilation,
594
+ bias=bias,
595
+ )
596
+
597
+ if spectral_norm:
598
+ self.conv = nn.utils.spectral_norm(self.conv)
599
+
600
+ def forward(self, input):
601
+ out = self.conv(input)
602
+ return out
603
+
604
+
605
+ class MultiHeadAttention(nn.Module):
606
+ """Multi-Head Attention module"""
607
+
608
+ def __init__(self, n_head, d_model, d_k, d_v, dropout=0.0, spectral_norm=False):
609
+ super().__init__()
610
+
611
+ self.n_head = n_head
612
+ self.d_k = d_k
613
+ self.d_v = d_v
614
+
615
+ self.w_qs = nn.Linear(d_model, n_head * d_k)
616
+ self.w_ks = nn.Linear(d_model, n_head * d_k)
617
+ self.w_vs = nn.Linear(d_model, n_head * d_v)
618
+
619
+ self.attention = ScaledDotProductAttention(
620
+ temperature=np.power(d_model, 0.5), dropout=dropout
621
+ )
622
+
623
+ self.fc = nn.Linear(n_head * d_v, d_model)
624
+ self.dropout = nn.Dropout(dropout)
625
+
626
+ if spectral_norm:
627
+ self.w_qs = nn.utils.spectral_norm(self.w_qs)
628
+ self.w_ks = nn.utils.spectral_norm(self.w_ks)
629
+ self.w_vs = nn.utils.spectral_norm(self.w_vs)
630
+ self.fc = nn.utils.spectral_norm(self.fc)
631
+
632
+ def forward(self, x, mask=None):
633
+ d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
634
+ sz_b, len_x, _ = x.size()
635
+
636
+ residual = x
637
+
638
+ q = self.w_qs(x).view(sz_b, len_x, n_head, d_k)
639
+ k = self.w_ks(x).view(sz_b, len_x, n_head, d_k)
640
+ v = self.w_vs(x).view(sz_b, len_x, n_head, d_v)
641
+ q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k) # (n*b) x lq x dk
642
+ k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k) # (n*b) x lk x dk
643
+ v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_v) # (n*b) x lv x dv
644
+
645
+ if mask is not None:
646
+ slf_mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
647
+ else:
648
+ slf_mask = None
649
+ output, attn = self.attention(q, k, v, mask=slf_mask)
650
+
651
+ output = output.view(n_head, sz_b, len_x, d_v)
652
+ output = (
653
+ output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1)
654
+ ) # b x lq x (n*dv)
655
+
656
+ output = self.fc(output)
657
+
658
+ output = self.dropout(output) + residual
659
+ return output, attn
660
+
661
+
662
+ class ScaledDotProductAttention(nn.Module):
663
+ """Scaled Dot-Product Attention"""
664
+
665
+ def __init__(self, temperature, dropout):
666
+ super().__init__()
667
+ self.temperature = temperature
668
+ self.softmax = nn.Softmax(dim=2)
669
+ self.dropout = nn.Dropout(dropout)
670
+
671
+ def forward(self, q, k, v, mask=None):
672
+ attn = torch.bmm(q, k.transpose(1, 2))
673
+ attn = attn / self.temperature
674
+
675
+ if mask is not None:
676
+ attn = attn.masked_fill(mask, -np.inf)
677
+
678
+ attn = self.softmax(attn)
679
+ p_attn = self.dropout(attn)
680
+
681
+ output = torch.bmm(p_attn, v)
682
+ return output, attn
683
+
684
+
685
+ class MelStyleEncoder(nn.Module):
686
+ """MelStyleEncoder"""
687
+
688
+ def __init__(
689
+ self,
690
+ n_mel_channels=80,
691
+ style_hidden=128,
692
+ style_vector_dim=256,
693
+ style_kernel_size=5,
694
+ style_head=2,
695
+ dropout=0.1,
696
+ ):
697
+ super(MelStyleEncoder, self).__init__()
698
+ self.in_dim = n_mel_channels
699
+ self.hidden_dim = style_hidden
700
+ self.out_dim = style_vector_dim
701
+ self.kernel_size = style_kernel_size
702
+ self.n_head = style_head
703
+ self.dropout = dropout
704
+
705
+ self.spectral = nn.Sequential(
706
+ LinearNorm(self.in_dim, self.hidden_dim),
707
+ Mish(),
708
+ nn.Dropout(self.dropout),
709
+ LinearNorm(self.hidden_dim, self.hidden_dim),
710
+ Mish(),
711
+ nn.Dropout(self.dropout),
712
+ )
713
+
714
+ self.temporal = nn.Sequential(
715
+ Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
716
+ Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
717
+ )
718
+
719
+ self.slf_attn = MultiHeadAttention(
720
+ self.n_head,
721
+ self.hidden_dim,
722
+ self.hidden_dim // self.n_head,
723
+ self.hidden_dim // self.n_head,
724
+ self.dropout,
725
+ )
726
+
727
+ self.fc = LinearNorm(self.hidden_dim, self.out_dim)
728
+
729
+ def temporal_avg_pool(self, x, mask=None):
730
+ if mask is None:
731
+ out = torch.mean(x, dim=1)
732
+ else:
733
+ len_ = (~mask).sum(dim=1).unsqueeze(1)
734
+ x = x.masked_fill(mask.unsqueeze(-1), 0)
735
+ x = x.sum(dim=1)
736
+ out = torch.div(x, len_)
737
+ return out
738
+
739
+ def forward(self, x, mask=None):
740
+ x = x.transpose(1, 2)
741
+ if mask is not None:
742
+ mask = (mask.int() == 0).squeeze(1)
743
+ max_len = x.shape[1]
744
+ slf_attn_mask = (
745
+ mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
746
+ )
747
+
748
+ # spectral
749
+ x = self.spectral(x)
750
+ # temporal
751
+ x = x.transpose(1, 2)
752
+ x = self.temporal(x)
753
+ x = x.transpose(1, 2)
754
+ # self-attention
755
+ if mask is not None:
756
+ x = x.masked_fill(mask.unsqueeze(-1), 0)
757
+ x, _ = self.slf_attn(x, mask=slf_attn_mask)
758
+ # fc
759
+ x = self.fc(x)
760
+ # temoral average pooling
761
+ w = self.temporal_avg_pool(x, mask=mask)
762
+
763
+ return w.unsqueeze(-1)
764
+
765
+
766
+ class MelStyleEncoderVAE(nn.Module):
767
+ def __init__(self, spec_channels, z_latent_dim, emb_dim):
768
+ super().__init__()
769
+ self.ref_encoder = MelStyleEncoder(spec_channels, style_vector_dim=emb_dim)
770
+ self.fc1 = nn.Linear(emb_dim, z_latent_dim)
771
+ self.fc2 = nn.Linear(emb_dim, z_latent_dim)
772
+ self.fc3 = nn.Linear(z_latent_dim, emb_dim)
773
+ self.z_latent_dim = z_latent_dim
774
+
775
+ def reparameterize(self, mu, logvar):
776
+ if self.training:
777
+ std = torch.exp(0.5 * logvar)
778
+ eps = torch.randn_like(std)
779
+ return eps.mul(std).add_(mu)
780
+ else:
781
+ return mu
782
+
783
+ def forward(self, inputs, mask=None):
784
+ enc_out = self.ref_encoder(inputs.squeeze(-1), mask).squeeze(-1)
785
+ mu = self.fc1(enc_out)
786
+ logvar = self.fc2(enc_out)
787
+ posterior = D.Normal(mu, torch.exp(logvar))
788
+ kl_divergence = D.kl_divergence(
789
+ posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar))
790
+ )
791
+ loss_kl = kl_divergence.mean()
792
+
793
+ z = posterior.rsample()
794
+ style_embed = self.fc3(z)
795
+
796
+ return style_embed.unsqueeze(-1), loss_kl
797
+
798
+ def infer(self, inputs=None, random_sample=False, manual_latent=None):
799
+ if manual_latent is None:
800
+ if random_sample:
801
+ dev = next(self.parameters()).device
802
+ posterior = D.Normal(
803
+ torch.zeros(1, self.z_latent_dim, device=dev),
804
+ torch.ones(1, self.z_latent_dim, device=dev),
805
+ )
806
+ z = posterior.rsample()
807
+ else:
808
+ enc_out = self.ref_encoder(inputs.transpose(1, 2))
809
+ mu = self.fc1(enc_out)
810
+ z = mu
811
+ else:
812
+ z = manual_latent
813
+ style_embed = self.fc3(z)
814
+ return style_embed.unsqueeze(-1), z
815
+
816
+
817
+ class ActNorm(nn.Module):
818
+ def __init__(self, channels, ddi=False, **kwargs):
819
+ super().__init__()
820
+ self.channels = channels
821
+ self.initialized = not ddi
822
+
823
+ self.logs = nn.Parameter(torch.zeros(1, channels, 1))
824
+ self.bias = nn.Parameter(torch.zeros(1, channels, 1))
825
+
826
+ def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs):
827
+ if x_mask is None:
828
+ x_mask = torch.ones(x.size(0), 1, x.size(2)).to(
829
+ device=x.device, dtype=x.dtype
830
+ )
831
+ x_len = torch.sum(x_mask, [1, 2])
832
+ if not self.initialized:
833
+ self.initialize(x, x_mask)
834
+ self.initialized = True
835
+
836
+ if reverse:
837
+ z = (x - self.bias) * torch.exp(-self.logs) * x_mask
838
+ logdet = None
839
+ return z
840
+ else:
841
+ z = (self.bias + torch.exp(self.logs) * x) * x_mask
842
+ logdet = torch.sum(self.logs) * x_len # [b]
843
+ return z, logdet
844
+
845
+ def store_inverse(self):
846
+ pass
847
+
848
+ def set_ddi(self, ddi):
849
+ self.initialized = not ddi
850
+
851
+ def initialize(self, x, x_mask):
852
+ with torch.no_grad():
853
+ denom = torch.sum(x_mask, [0, 2])
854
+ m = torch.sum(x * x_mask, [0, 2]) / denom
855
+ m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
856
+ v = m_sq - (m**2)
857
+ logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
858
+
859
+ bias_init = (
860
+ (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
861
+ )
862
+ logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)
863
+
864
+ self.bias.data.copy_(bias_init)
865
+ self.logs.data.copy_(logs_init)
866
+
867
+
868
+ class InvConvNear(nn.Module):
869
+ def __init__(self, channels, n_split=4, no_jacobian=False, **kwargs):
870
+ super().__init__()
871
+ assert n_split % 2 == 0
872
+ self.channels = channels
873
+ self.n_split = n_split
874
+ self.no_jacobian = no_jacobian
875
+
876
+ w_init = torch.linalg.qr(
877
+ torch.FloatTensor(self.n_split, self.n_split).normal_()
878
+ )[0]
879
+ if torch.det(w_init) < 0:
880
+ w_init[:, 0] = -1 * w_init[:, 0]
881
+ self.weight = nn.Parameter(w_init)
882
+
883
+ def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs):
884
+ b, c, t = x.size()
885
+ assert c % self.n_split == 0
886
+ if x_mask is None:
887
+ x_mask = 1
888
+ x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
889
+ else:
890
+ x_len = torch.sum(x_mask, [1, 2])
891
+
892
+ x = x.view(b, 2, c // self.n_split, self.n_split // 2, t)
893
+ x = (
894
+ x.permute(0, 1, 3, 2, 4)
895
+ .contiguous()
896
+ .view(b, self.n_split, c // self.n_split, t)
897
+ )
898
+
899
+ if reverse:
900
+ if hasattr(self, "weight_inv"):
901
+ weight = self.weight_inv
902
+ else:
903
+ weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
904
+ logdet = None
905
+ else:
906
+ weight = self.weight
907
+ if self.no_jacobian:
908
+ logdet = 0
909
+ else:
910
+ logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len # [b]
911
+
912
+ weight = weight.view(self.n_split, self.n_split, 1, 1)
913
+ z = F.conv2d(x, weight)
914
+
915
+ z = z.view(b, 2, self.n_split // 2, c // self.n_split, t)
916
+ z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask
917
+ if reverse:
918
+ return z
919
+ else:
920
+ return z, logdet
921
+
922
+ def store_inverse(self):
923
+ self.weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
module/mrte_model.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is Multi-reference timbre encoder
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn.utils import remove_weight_norm, weight_norm
6
+ from module.attentions import MultiHeadAttention
7
+
8
+
9
+ class MRTE(nn.Module):
10
+ def __init__(
11
+ self,
12
+ content_enc_channels=192,
13
+ hidden_size=512,
14
+ out_channels=192,
15
+ kernel_size=5,
16
+ n_heads=4,
17
+ ge_layer=2,
18
+ ):
19
+ super(MRTE, self).__init__()
20
+ self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
21
+ self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
22
+ self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
23
+ self.c_post = nn.Conv1d(hidden_size, out_channels, 1)
24
+
25
+ def forward(self, ssl_enc, ssl_mask, text, text_mask, ge, test=None):
26
+ if ge == None:
27
+ ge = 0
28
+ attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)
29
+
30
+ ssl_enc = self.c_pre(ssl_enc * ssl_mask)
31
+ text_enc = self.text_pre(text * text_mask)
32
+ if test != None:
33
+ if test == 0:
34
+ x = (
35
+ self.cross_attention(
36
+ ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
37
+ )
38
+ + ssl_enc
39
+ + ge
40
+ )
41
+ elif test == 1:
42
+ x = ssl_enc + ge
43
+ elif test == 2:
44
+ x = (
45
+ self.cross_attention(
46
+ ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask
47
+ )
48
+ + ge
49
+ )
50
+ else:
51
+ raise ValueError("test should be 0,1,2")
52
+ else:
53
+ x = (
54
+ self.cross_attention(
55
+ ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
56
+ )
57
+ + ssl_enc
58
+ + ge
59
+ )
60
+ x = self.c_post(x * ssl_mask)
61
+ return x
62
+
63
+
64
+ class SpeakerEncoder(torch.nn.Module):
65
+ def __init__(
66
+ self,
67
+ mel_n_channels=80,
68
+ model_num_layers=2,
69
+ model_hidden_size=256,
70
+ model_embedding_size=256,
71
+ ):
72
+ super(SpeakerEncoder, self).__init__()
73
+ self.lstm = nn.LSTM(
74
+ mel_n_channels, model_hidden_size, model_num_layers, batch_first=True
75
+ )
76
+ self.linear = nn.Linear(model_hidden_size, model_embedding_size)
77
+ self.relu = nn.ReLU()
78
+
79
+ def forward(self, mels):
80
+ self.lstm.flatten_parameters()
81
+ _, (hidden, _) = self.lstm(mels.transpose(-1, -2))
82
+ embeds_raw = self.relu(self.linear(hidden[-1]))
83
+ return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
84
+
85
+
86
+ class MELEncoder(nn.Module):
87
+ def __init__(
88
+ self,
89
+ in_channels,
90
+ out_channels,
91
+ hidden_channels,
92
+ kernel_size,
93
+ dilation_rate,
94
+ n_layers,
95
+ ):
96
+ super().__init__()
97
+ self.in_channels = in_channels
98
+ self.out_channels = out_channels
99
+ self.hidden_channels = hidden_channels
100
+ self.kernel_size = kernel_size
101
+ self.dilation_rate = dilation_rate
102
+ self.n_layers = n_layers
103
+
104
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
105
+ self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers)
106
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
107
+
108
+ def forward(self, x):
109
+ # print(x.shape,x_lengths.shape)
110
+ x = self.pre(x)
111
+ x = self.enc(x)
112
+ x = self.proj(x)
113
+ return x
114
+
115
+
116
+ class WN(torch.nn.Module):
117
+ def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers):
118
+ super(WN, self).__init__()
119
+ assert kernel_size % 2 == 1
120
+ self.hidden_channels = hidden_channels
121
+ self.kernel_size = kernel_size
122
+ self.dilation_rate = dilation_rate
123
+ self.n_layers = n_layers
124
+
125
+ self.in_layers = torch.nn.ModuleList()
126
+ self.res_skip_layers = torch.nn.ModuleList()
127
+
128
+ for i in range(n_layers):
129
+ dilation = dilation_rate**i
130
+ padding = int((kernel_size * dilation - dilation) / 2)
131
+ in_layer = nn.Conv1d(
132
+ hidden_channels,
133
+ 2 * hidden_channels,
134
+ kernel_size,
135
+ dilation=dilation,
136
+ padding=padding,
137
+ )
138
+ in_layer = weight_norm(in_layer)
139
+ self.in_layers.append(in_layer)
140
+
141
+ # last one is not necessary
142
+ if i < n_layers - 1:
143
+ res_skip_channels = 2 * hidden_channels
144
+ else:
145
+ res_skip_channels = hidden_channels
146
+
147
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
148
+ res_skip_layer = weight_norm(res_skip_layer, name="weight")
149
+ self.res_skip_layers.append(res_skip_layer)
150
+
151
+ def forward(self, x):
152
+ output = torch.zeros_like(x)
153
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
154
+
155
+ for i in range(self.n_layers):
156
+ x_in = self.in_layers[i](x)
157
+
158
+ acts = fused_add_tanh_sigmoid_multiply(x_in, n_channels_tensor)
159
+
160
+ res_skip_acts = self.res_skip_layers[i](acts)
161
+ if i < self.n_layers - 1:
162
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
163
+ x = x + res_acts
164
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
165
+ else:
166
+ output = output + res_skip_acts
167
+ return output
168
+
169
+ def remove_weight_norm(self):
170
+ for l in self.in_layers:
171
+ remove_weight_norm(l)
172
+ for l in self.res_skip_layers:
173
+ remove_weight_norm(l)
174
+
175
+
176
+ @torch.jit.script
177
+ def fused_add_tanh_sigmoid_multiply(input, n_channels):
178
+ n_channels_int = n_channels[0]
179
+ t_act = torch.tanh(input[:, :n_channels_int, :])
180
+ s_act = torch.sigmoid(input[:, n_channels_int:, :])
181
+ acts = t_act * s_act
182
+ return acts
183
+
184
+
185
+ if __name__ == "__main__":
186
+ content_enc = torch.randn(3, 192, 100)
187
+ content_mask = torch.ones(3, 1, 100)
188
+ ref_mel = torch.randn(3, 128, 30)
189
+ ref_mask = torch.ones(3, 1, 30)
190
+ model = MRTE()
191
+ out = model(content_enc, content_mask, ref_mel, ref_mask)
192
+ print(out.shape)
module/quantize.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Residual vector quantizer implementation."""
8
+
9
+ from dataclasses import dataclass, field
10
+ import math
11
+ import typing as tp
12
+
13
+ import torch
14
+ from torch import nn
15
+
16
+ from module.core_vq import ResidualVectorQuantization
17
+
18
+
19
+ @dataclass
20
+ class QuantizedResult:
21
+ quantized: torch.Tensor
22
+ codes: torch.Tensor
23
+ bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
24
+ penalty: tp.Optional[torch.Tensor] = None
25
+ metrics: dict = field(default_factory=dict)
26
+
27
+
28
+ class ResidualVectorQuantizer(nn.Module):
29
+ """Residual Vector Quantizer.
30
+ Args:
31
+ dimension (int): Dimension of the codebooks.
32
+ n_q (int): Number of residual vector quantizers used.
33
+ bins (int): Codebook size.
34
+ decay (float): Decay for exponential moving average over the codebooks.
35
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
36
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
37
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
38
+ that have an exponential moving average cluster size less than the specified threshold with
39
+ randomly selected vector from the current batch.
40
+ """
41
+
42
+ def __init__(
43
+ self,
44
+ dimension: int = 256,
45
+ n_q: int = 8,
46
+ bins: int = 1024,
47
+ decay: float = 0.99,
48
+ kmeans_init: bool = True,
49
+ kmeans_iters: int = 50,
50
+ threshold_ema_dead_code: int = 2,
51
+ ):
52
+ super().__init__()
53
+ self.n_q = n_q
54
+ self.dimension = dimension
55
+ self.bins = bins
56
+ self.decay = decay
57
+ self.kmeans_init = kmeans_init
58
+ self.kmeans_iters = kmeans_iters
59
+ self.threshold_ema_dead_code = threshold_ema_dead_code
60
+ self.vq = ResidualVectorQuantization(
61
+ dim=self.dimension,
62
+ codebook_size=self.bins,
63
+ num_quantizers=self.n_q,
64
+ decay=self.decay,
65
+ kmeans_init=self.kmeans_init,
66
+ kmeans_iters=self.kmeans_iters,
67
+ threshold_ema_dead_code=self.threshold_ema_dead_code,
68
+ )
69
+
70
+ def forward(
71
+ self,
72
+ x: torch.Tensor,
73
+ n_q: tp.Optional[int] = None,
74
+ layers: tp.Optional[list] = None,
75
+ ) -> QuantizedResult:
76
+ """Residual vector quantization on the given input tensor.
77
+ Args:
78
+ x (torch.Tensor): Input tensor.
79
+ n_q (int): Number of quantizer used to quantize. Default: All quantizers.
80
+ layers (list): Layer that need to return quantized. Defalt: None.
81
+ Returns:
82
+ QuantizedResult:
83
+ The quantized (or approximately quantized) representation with
84
+ the associated numbert quantizers and layer quantized required to return.
85
+ """
86
+ n_q = n_q if n_q else self.n_q
87
+ if layers and max(layers) >= n_q:
88
+ raise ValueError(
89
+ f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B."
90
+ )
91
+ quantized, codes, commit_loss, quantized_list = self.vq(
92
+ x, n_q=n_q, layers=layers
93
+ )
94
+ return quantized, codes, torch.mean(commit_loss), quantized_list
95
+
96
+ def encode(
97
+ self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
98
+ ) -> torch.Tensor:
99
+ """Encode a given input tensor with the specified sample rate at the given bandwidth.
100
+ The RVQ encode method sets the appropriate number of quantizer to use
101
+ and returns indices for each quantizer.
102
+ Args:
103
+ x (torch.Tensor): Input tensor.
104
+ n_q (int): Number of quantizer used to quantize. Default: All quantizers.
105
+ st (int): Start to encode input from which layers. Default: 0.
106
+ """
107
+ n_q = n_q if n_q else self.n_q
108
+ st = st or 0
109
+ codes = self.vq.encode(x, n_q=n_q, st=st)
110
+ return codes
111
+
112
+ def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
113
+ """Decode the given codes to the quantized representation.
114
+ Args:
115
+ codes (torch.Tensor): Input indices for each quantizer.
116
+ st (int): Start to decode input codes from which layers. Default: 0.
117
+ """
118
+ quantized = self.vq.decode(codes, st=st)
119
+ return quantized
module/transforms.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+ import numpy as np
5
+
6
+
7
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
8
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
9
+ DEFAULT_MIN_DERIVATIVE = 1e-3
10
+
11
+
12
+ def piecewise_rational_quadratic_transform(
13
+ inputs,
14
+ unnormalized_widths,
15
+ unnormalized_heights,
16
+ unnormalized_derivatives,
17
+ inverse=False,
18
+ tails=None,
19
+ tail_bound=1.0,
20
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
21
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
22
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
23
+ ):
24
+ if tails is None:
25
+ spline_fn = rational_quadratic_spline
26
+ spline_kwargs = {}
27
+ else:
28
+ spline_fn = unconstrained_rational_quadratic_spline
29
+ spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
30
+
31
+ outputs, logabsdet = spline_fn(
32
+ inputs=inputs,
33
+ unnormalized_widths=unnormalized_widths,
34
+ unnormalized_heights=unnormalized_heights,
35
+ unnormalized_derivatives=unnormalized_derivatives,
36
+ inverse=inverse,
37
+ min_bin_width=min_bin_width,
38
+ min_bin_height=min_bin_height,
39
+ min_derivative=min_derivative,
40
+ **spline_kwargs
41
+ )
42
+ return outputs, logabsdet
43
+
44
+
45
+ def searchsorted(bin_locations, inputs, eps=1e-6):
46
+ bin_locations[..., -1] += eps
47
+ return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
48
+
49
+
50
+ def unconstrained_rational_quadratic_spline(
51
+ inputs,
52
+ unnormalized_widths,
53
+ unnormalized_heights,
54
+ unnormalized_derivatives,
55
+ inverse=False,
56
+ tails="linear",
57
+ tail_bound=1.0,
58
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
59
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
60
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
61
+ ):
62
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
63
+ outside_interval_mask = ~inside_interval_mask
64
+
65
+ outputs = torch.zeros_like(inputs)
66
+ logabsdet = torch.zeros_like(inputs)
67
+
68
+ if tails == "linear":
69
+ unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
70
+ constant = np.log(np.exp(1 - min_derivative) - 1)
71
+ unnormalized_derivatives[..., 0] = constant
72
+ unnormalized_derivatives[..., -1] = constant
73
+
74
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
75
+ logabsdet[outside_interval_mask] = 0
76
+ else:
77
+ raise RuntimeError("{} tails are not implemented.".format(tails))
78
+
79
+ (
80
+ outputs[inside_interval_mask],
81
+ logabsdet[inside_interval_mask],
82
+ ) = rational_quadratic_spline(
83
+ inputs=inputs[inside_interval_mask],
84
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87
+ inverse=inverse,
88
+ left=-tail_bound,
89
+ right=tail_bound,
90
+ bottom=-tail_bound,
91
+ top=tail_bound,
92
+ min_bin_width=min_bin_width,
93
+ min_bin_height=min_bin_height,
94
+ min_derivative=min_derivative,
95
+ )
96
+
97
+ return outputs, logabsdet
98
+
99
+
100
+ def rational_quadratic_spline(
101
+ inputs,
102
+ unnormalized_widths,
103
+ unnormalized_heights,
104
+ unnormalized_derivatives,
105
+ inverse=False,
106
+ left=0.0,
107
+ right=1.0,
108
+ bottom=0.0,
109
+ top=1.0,
110
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
111
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
112
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
113
+ ):
114
+ if torch.min(inputs) < left or torch.max(inputs) > right:
115
+ raise ValueError("Input to a transform is not within its domain")
116
+
117
+ num_bins = unnormalized_widths.shape[-1]
118
+
119
+ if min_bin_width * num_bins > 1.0:
120
+ raise ValueError("Minimal bin width too large for the number of bins")
121
+ if min_bin_height * num_bins > 1.0:
122
+ raise ValueError("Minimal bin height too large for the number of bins")
123
+
124
+ widths = F.softmax(unnormalized_widths, dim=-1)
125
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
126
+ cumwidths = torch.cumsum(widths, dim=-1)
127
+ cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
128
+ cumwidths = (right - left) * cumwidths + left
129
+ cumwidths[..., 0] = left
130
+ cumwidths[..., -1] = right
131
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
132
+
133
+ derivatives = min_derivative + F.softplus(unnormalized_derivatives)
134
+
135
+ heights = F.softmax(unnormalized_heights, dim=-1)
136
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
137
+ cumheights = torch.cumsum(heights, dim=-1)
138
+ cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
139
+ cumheights = (top - bottom) * cumheights + bottom
140
+ cumheights[..., 0] = bottom
141
+ cumheights[..., -1] = top
142
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
143
+
144
+ if inverse:
145
+ bin_idx = searchsorted(cumheights, inputs)[..., None]
146
+ else:
147
+ bin_idx = searchsorted(cumwidths, inputs)[..., None]
148
+
149
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
150
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
151
+
152
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
153
+ delta = heights / widths
154
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
155
+
156
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
157
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
158
+
159
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
160
+
161
+ if inverse:
162
+ a = (inputs - input_cumheights) * (
163
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
164
+ ) + input_heights * (input_delta - input_derivatives)
165
+ b = input_heights * input_derivatives - (inputs - input_cumheights) * (
166
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
167
+ )
168
+ c = -input_delta * (inputs - input_cumheights)
169
+
170
+ discriminant = b.pow(2) - 4 * a * c
171
+ assert (discriminant >= 0).all()
172
+
173
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
174
+ outputs = root * input_bin_widths + input_cumwidths
175
+
176
+ theta_one_minus_theta = root * (1 - root)
177
+ denominator = input_delta + (
178
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
179
+ * theta_one_minus_theta
180
+ )
181
+ derivative_numerator = input_delta.pow(2) * (
182
+ input_derivatives_plus_one * root.pow(2)
183
+ + 2 * input_delta * theta_one_minus_theta
184
+ + input_derivatives * (1 - root).pow(2)
185
+ )
186
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
187
+
188
+ return outputs, -logabsdet
189
+ else:
190
+ theta = (inputs - input_cumwidths) / input_bin_widths
191
+ theta_one_minus_theta = theta * (1 - theta)
192
+
193
+ numerator = input_heights * (
194
+ input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
195
+ )
196
+ denominator = input_delta + (
197
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
198
+ * theta_one_minus_theta
199
+ )
200
+ outputs = input_cumheights + numerator / denominator
201
+
202
+ derivative_numerator = input_delta.pow(2) * (
203
+ input_derivatives_plus_one * theta.pow(2)
204
+ + 2 * input_delta * theta_one_minus_theta
205
+ + input_derivatives * (1 - theta).pow(2)
206
+ )
207
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
208
+
209
+ return outputs, logabsdet
my_utils.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ffmpeg
2
+ import numpy as np
3
+
4
+
5
+ def load_audio(file, sr):
6
+ try:
7
+ # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
8
+ # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
9
+ # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
10
+ file = (
11
+ file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
12
+ ) # 防止小白拷路径头尾带了空格和"和回车
13
+ out, _ = (
14
+ ffmpeg.input(file, threads=0)
15
+ .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
16
+ .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
17
+ )
18
+ except Exception as e:
19
+ raise RuntimeError(f"Failed to load audio: {e}")
20
+
21
+ return np.frombuffer(out, np.float32).flatten()