admin117 committed
Commit bd3dc09 · 1 Parent(s): 86970b7

Initial commit without LFS files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitignore +3 -0
  2. AR/models/embedding.py +98 -0
  3. AR/models/structs.py +91 -0
  4. AR/models/t2s_model_abc.py +538 -0
  5. AR/models/t2s_model_flash_attn.py +408 -0
  6. configs/s1.yaml +31 -0
  7. configs/s1big.yaml +31 -0
  8. configs/s1big2.yaml +31 -0
  9. configs/s1longer-v2.yaml +31 -0
  10. configs/s1longer.yaml +31 -0
  11. configs/s1mq.yaml +77 -0
  12. configs/s2.json +90 -0
  13. configs/train.yaml +32 -0
  14. download.py +5 -0
  15. eres2net/ERes2Net.py +260 -0
  16. eres2net/ERes2NetV2.py +292 -0
  17. eres2net/ERes2Net_huge.py +286 -0
  18. eres2net/fusion.py +29 -0
  19. eres2net/kaldi.py +819 -0
  20. eres2net/pooling_layers.py +104 -0
  21. feature_extractor/__init__.py +6 -0
  22. feature_extractor/cnhubert.py +109 -0
  23. feature_extractor/whisper_enc.py +25 -0
  24. inference_webui.py +867 -0
  25. module/__init__.py +0 -0
  26. module/attentions.py +709 -0
  27. module/attentions_onnx.py +354 -0
  28. module/commons.py +189 -0
  29. module/core_vq.py +383 -0
  30. module/data_utils.py +332 -0
  31. module/losses.py +73 -0
  32. module/mel_processing.py +153 -0
  33. module/models.py +1040 -0
  34. module/models_onnx.py +918 -0
  35. module/modules.py +923 -0
  36. module/mrte_model.py +192 -0
  37. module/quantize.py +119 -0
  38. module/transforms.py +209 -0
  39. packages.txt +1 -0
  40. pre-requirements.txt +2 -0
  41. pretrained_models/chinese-hubert-base/config.json +72 -0
  42. pretrained_models/chinese-hubert-base/preprocessor_config.json +9 -0
  43. pretrained_models/chinese-roberta-wwm-ext-large/config.json +34 -0
  44. pretrained_models/chinese-roberta-wwm-ext-large/tokenizer.json +0 -0
  45. process_ckpt.py +31 -0
  46. requirements.txt +36 -0
  47. sv.py +24 -0
  48. text/.gitignore +3 -0
  49. text/LangSegmenter/__init__.py +1 -0
  50. text/LangSegmenter/langsegmenter.py +175 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
1
+ *.pickle
2
+ text/ja_userdic/user.dict
3
+ text/ja_userdic/userdict.csv
AR/models/embedding.py ADDED
@@ -0,0 +1,98 @@
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
2
+ import math
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+
8
+ class TokenEmbedding(nn.Module):
9
+ def __init__(
10
+ self,
11
+ embedding_dim: int,
12
+ vocab_size: int,
13
+ dropout: float = 0.0,
14
+ ):
15
+ super().__init__()
16
+
17
+ self.vocab_size = vocab_size
18
+ self.embedding_dim = embedding_dim
19
+
20
+ self.dropout = torch.nn.Dropout(p=dropout)
21
+ self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
22
+
23
+ @property
24
+ def weight(self) -> torch.Tensor:
25
+ return self.word_embeddings.weight
26
+
27
+ def embedding(self, index: int) -> torch.Tensor:
28
+ return self.word_embeddings.weight[index : index + 1]
29
+
30
+ def forward(self, x: torch.Tensor):
31
+ x = self.word_embeddings(x)
32
+ x = self.dropout(x)
33
+ return x
34
+
35
+
36
+ class SinePositionalEmbeddingNested(nn.Module):
37
+ def __init__(
38
+ self,
39
+ embedding_dim: int,
40
+ dropout: float = 0.0,
41
+ scale: bool = False,
42
+ alpha: bool = False,
43
+ max_batch_size: int = 20,
44
+ max_seq_len: int = 2500,
45
+ ):
46
+ super().__init__()
47
+ self.embedding_dim = embedding_dim
48
+ self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
49
+ self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
50
+ self.dropout = torch.nn.Dropout(p=dropout)
51
+ self.max_batch_size = max_batch_size
52
+ self.max_seq_len = max_seq_len
53
+
54
+ self.reverse = False
55
+ self.register_buffer("pe", torch.zeros(max_batch_size, max_seq_len, embedding_dim), persistent=False)
56
+ self.pe: torch.Tensor
57
+ self.compute_pe()
58
+
59
+ def compute_pe(self):
60
+ """Reset the positional encodings."""
61
+ if self.reverse:
62
+ position = torch.arange(self.max_seq_len - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
63
+ else:
64
+ position = torch.arange(self.max_seq_len, dtype=torch.float32).unsqueeze(1)
65
+ div_term = torch.exp(
66
+ torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
67
+ )
68
+ pe = self.pe
69
+ pe[:, :, 0::2] = torch.sin(position * div_term)
70
+ pe[:, :, 1::2] = torch.cos(position * div_term)
71
+
72
+ def forward(self, input_pos: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
73
+ """
74
+ Args:
75
+ input_pos (Tensor): [batch_size, ]
76
+ x (Tensor): [batch_size, 1, embed_dim]
77
+
78
+ Returns:
79
+ embedded_x (Tensor): [batch_size, 1, embed_dim]
80
+ """
81
+
82
+ batch_size = x.shape[0]
83
+ pe_values = self.pe[torch.arange(batch_size), input_pos - 1] # (batch_size, embed_dim)
84
+
85
+ return x * self.x_scale + self.alpha * pe_values.unsqueeze(1) # (batch_size, 1, embed_dim)
86
+
87
+ def prefill(self, x: torch.Tensor) -> torch.Tensor:
88
+ """
89
+ Args:
90
+ x (Tensor): Nested Seqlen [batch_size, seq_len, embed_dim]
91
+
92
+ Returns:
93
+ embedded_x (Tensor): Nested Seqlen [batch_size, seq_len, embed_dim]
94
+ """
95
+
96
+ input_pos: torch.Tensor = torch.tensor([i.shape[0] for i in x.unbind()])
97
+ pe_values = torch.nested.nested_tensor([self.pe[i, : input_pos[i], :] for i in range(input_pos.size(0))])
98
+ return x * self.x_scale + self.alpha.item() * pe_values
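
A minimal usage sketch of the two modules above, assuming they are importable as AR.models.embedding; the sizes are illustrative, and forward consumes one new token position per batch element as the docstring describes:

import torch

from AR.models.embedding import SinePositionalEmbeddingNested, TokenEmbedding

# Illustrative sizes only.
vocab_size, embed_dim, batch = 1025, 512, 2

token_emb = TokenEmbedding(embedding_dim=embed_dim, vocab_size=vocab_size)
pos_emb = SinePositionalEmbeddingNested(embed_dim, alpha=True, max_batch_size=batch)

tokens = torch.randint(0, vocab_size, (batch, 1))   # one new token per sequence
x = token_emb(tokens)                               # (batch, 1, embed_dim)
input_pos = torch.tensor([12, 7])                   # current length of each sequence
step = pos_emb(input_pos, x)                        # (batch, 1, embed_dim)
print(step.shape)
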
AR/models/structs.py ADDED
@@ -0,0 +1,91 @@
1
+ """
2
+ Modified From https://github.com/XXXXRT666/GPT-SoVITS
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from dataclasses import dataclass
8
+ from typing import List, Literal, MutableSequence, Optional
9
+
10
+ import torch
11
+
12
+ from AR.models.t2s_model_abc import KVCacheABC, Sampler, T2SDecoderABC
13
+
14
+ Tensor = torch.Tensor
15
+
16
+
17
+ @dataclass
18
+ class T2SResult:
19
+ result: List[Tensor] | None = None
20
+ infer_speed: float = 0.0
21
+ status: Literal["Success", "Error"] = "Success"
22
+ exception: Optional[Exception] = None
23
+ traceback: Optional[str] = None
24
+
25
+
26
+ @dataclass
27
+ class T2SRequest:
28
+ x: List[torch.Tensor]
29
+ x_lens: Tensor
30
+ prompts: torch.Tensor
31
+ bert_feature: List[Tensor]
32
+ valid_length: int
33
+ top_k: int = 5
34
+ top_p: float = 1
35
+ early_stop_num: int = -1
36
+ temperature: float = 1.0
37
+ repetition_penalty: float = 1.35
38
+ use_cuda_graph: bool = False
39
+ debug: bool = False
40
+
41
+
42
+ class T2SSession:
43
+ def __init__(self, decoder: T2SDecoderABC, request: T2SRequest, device: torch.device, dtype: torch.dtype):
44
+ with device:
45
+ self.decoder = decoder
46
+ self.request = request
47
+ self.device = device
48
+ self.dtype = dtype
49
+
50
+ bsz = len(request.x)
51
+ y_len = request.prompts.size(-1)
52
+ self.bsz = bsz
53
+ self.y_len = y_len
54
+
55
+ # Cache
56
+ self.kv_cache: MutableSequence[KVCacheABC]
57
+ self.sampler = Sampler(bsz, decoder.vocab_size)
58
+
59
+ # Forward args
60
+ self.x = request.x
61
+ self.x_lens = request.x_lens.to(torch.int32)
62
+ self.y = request.prompts
63
+ self.bert_feature = request.bert_feature
64
+
65
+ self.prefill_len = self.x_lens + self.y.size(1)
66
+
67
+ self.input_pos = torch.zeros_like(self.prefill_len)
68
+ self.input_pos.add_(self.prefill_len)
69
+
70
+ # CUDA Graph
71
+ self.graph: Optional[torch.cuda.CUDAGraph] = None
72
+ self.xy_pos_: Tensor
73
+ self.xy_dec_: Tensor
74
+
75
+ # EOS
76
+ self.completed = torch.Tensor([False] * len(self.x)).bool().to(device)
77
+ self.y_results: List[Tensor] = [None] * len(self.x) # type: ignore
78
+
79
+ self.xy_pos = decoder.embed(self.x, self.y, self.bert_feature)
80
+
81
+ attn_mask = []
82
+ for bs in range(bsz):
83
+ pos = int(self.x_lens[bs].item())
84
+ mask = torch.zeros(pos + y_len, pos + y_len).bool()
85
+ mask[:, :pos].fill_(True)
86
+ if y_len > 0:
87
+ mask[-y_len:, -y_len:] = ~torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1)
88
+ attn_mask.append(mask)
89
+ self.attn_mask_nested = torch.nested.nested_tensor(attn_mask)
90
+
91
+ self.id: int = -1
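
A hedged sketch of how a request might be assembled from these structures; the tensors are placeholders, since real phoneme IDs, BERT features, and prompt semantic tokens come from the text frontend and the reference audio:

import torch

from AR.models.structs import T2SRequest

phonemes = [torch.randint(0, 512, (32,)), torch.randint(0, 512, (40,))]
bert_feats = [torch.zeros(1024, 32), torch.zeros(1024, 40)]   # (1024, phoneme_len) each
prompt = torch.randint(0, 1024, (2, 150))                     # prompt semantic tokens

request = T2SRequest(
    x=phonemes,
    x_lens=torch.tensor([32, 40]),
    prompts=prompt,
    bert_feature=bert_feats,
    valid_length=2,
    top_k=15,
    use_cuda_graph=True,
)

A T2SSession is then built by the runner from this request together with a decoder, device, and dtype.
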
AR/models/t2s_model_abc.py ADDED
@@ -0,0 +1,538 @@
1
+ """
2
+ Modified From https://github.com/XXXXRT666/GPT-SoVITS
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import os
8
+ import random
9
+ from abc import ABC, abstractmethod
10
+ from contextlib import nullcontext
11
+ from typing import Any, Dict, List, MutableSequence, Tuple, Type
12
+
13
+ import torch
14
+ import torch._inductor.config
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from torch.cuda.graphs import CUDAGraph
18
+ from torch.profiler import ProfilerAction, tensorboard_trace_handler
19
+
20
+ from AR.models.embedding import (
21
+ SinePositionalEmbeddingNested as SinePositionalEmbedding,
22
+ )
23
+ from AR.models.embedding import TokenEmbedding
24
+
25
+ Tensor = torch.Tensor
26
+
27
+
28
+ class Sampler(nn.Module):
29
+ def __init__(self, batch_size: int, vocab_size: int) -> None:
30
+ super().__init__()
31
+ self.batch_size = batch_size
32
+
33
+ # @torch.jit.script
34
+ def sample(
35
+ self,
36
+ logits: Tensor,
37
+ previous_tokens: Tensor,
38
+ temperature: float,
39
+ top_k: int,
40
+ top_p: float,
41
+ repetition_penalty: float,
42
+ ) -> Tensor:
43
+ previous_tokens = previous_tokens.long()
44
+ score = torch.gather(logits, dim=1, index=previous_tokens)
45
+ score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
46
+ logits.scatter_(dim=1, index=previous_tokens, src=score)
47
+
48
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
49
+ cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
50
+ sorted_indices_to_remove = cum_probs > top_p
51
+ sorted_indices_to_remove[:, 0] = False # keep at least one option
52
+ indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
53
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
54
+
55
+ logits = logits / max(temperature, 1e-5)
56
+
57
+ v, _ = torch.topk(logits, top_k)
58
+ pivot = v[:, -1].unsqueeze(-1)
59
+ logits = torch.where(logits < pivot, -float("Inf"), logits)
60
+
61
+ probs = torch.nn.functional.softmax(logits, dim=-1)
62
+ q = torch.empty_like(probs).exponential_(1.0)
63
+ idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32)
64
+
65
+ return idx_next
66
+
67
+
68
+ class KVCacheABC(ABC, nn.Module):
69
+ def __init__(self, *args, **kwds) -> None:
70
+ super().__init__()
71
+ self.k_cache: Tensor
72
+ self.v_cache: Tensor
73
+ self.n_head: int
74
+ self.head_dim: int
75
+ self.batch_size: int
76
+ self.max_seq_length: int
77
+
78
+ def empty(self):
79
+ self.k_cache.zero_()
80
+ self.v_cache.zero_()
81
+
82
+ @abstractmethod
83
+ def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor, *args, **kwds) -> Tuple[Tensor, Tensor]: ...
84
+
85
+ @abstractmethod
86
+ def prefill_kv(self, k_val: Tensor, v_val: Tensor, bs: int) -> None: ...
87
+
88
+ def sync_cache(self, kv_cache: KVCacheABC):
89
+ self.k_cache.copy_(kv_cache.k_cache)
90
+ self.v_cache.copy_(kv_cache.v_cache)
91
+
92
+ def forward(self):
93
+ raise NotImplementedError()
94
+
95
+
96
+ class KVCacheNHD(KVCacheABC):
97
+ def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
98
+ super().__init__()
99
+ assert batch_size > 0
100
+ cache_shape = (batch_size, max_seq_length, n_heads, head_dim)
101
+ self.n_head = n_heads
102
+ self.head_dim = head_dim
103
+ self.batch_size = batch_size
104
+ self.max_seq_length = max_seq_length
105
+
106
+ self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
107
+ self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)
108
+
109
+ def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
110
+ # input_pos: [B, ], k_val: [B, 1, H, D]
111
+
112
+ index = (
113
+ (input_pos - 1)
114
+ .unsqueeze(-1)
115
+ .unsqueeze(-1)
116
+ .unsqueeze(-1)
117
+ .expand(
118
+ -1,
119
+ -1,
120
+ self.n_head,
121
+ self.head_dim,
122
+ )
123
+ .to(torch.int64)
124
+ ) # (bs, 1, num_head, head_dim)
125
+
126
+ k_out = self.k_cache
127
+ v_out = self.v_cache
128
+ k_out.scatter_(1, index, k_val)
129
+ v_out.scatter_(1, index, v_val)
130
+
131
+ return k_out, v_out
132
+
133
+ def empty(self):
134
+ self.k_cache.zero_()
135
+ self.v_cache.zero_()
136
+
137
+ def prefill_kv(self, k_val: Tensor, v_val: Tensor, bs: int):
138
+ # input_pos: int, k_val: [B, S, H, D]
139
+
140
+ self.k_cache[[bs], : k_val.shape[1]] = k_val
141
+ self.v_cache[[bs], : v_val.shape[1]] = v_val
142
+
143
+
144
+ class KVCacheHND(KVCacheABC):
145
+ def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
146
+ super().__init__()
147
+ assert batch_size > 0
148
+ cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
149
+ self.n_head = n_heads
150
+ self.head_dim = head_dim
151
+ self.batch_size = batch_size
152
+ self.max_seq_length = max_seq_length
153
+
154
+ self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
155
+ self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)
156
+
157
+ def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
158
+ # input_pos: [B, ], k_val: [B, H, 1, D]
159
+
160
+ index = (
161
+ (input_pos - 1)
162
+ .unsqueeze(-1)
163
+ .unsqueeze(-1)
164
+ .unsqueeze(-1)
165
+ .expand(
166
+ -1,
167
+ self.n_head,
168
+ -1,
169
+ self.head_dim,
170
+ )
171
+ .to(torch.int64)
172
+ ) # (bs, num_head, 1, head_dim)
173
+
174
+ k_out = self.k_cache
175
+ v_out = self.v_cache
176
+ k_out.scatter_(2, index, k_val)
177
+ v_out.scatter_(2, index, v_val)
178
+
179
+ return k_out, v_out
180
+
181
+ def empty(self):
182
+ self.k_cache.zero_()
183
+ self.v_cache.zero_()
184
+
185
+ def prefill_kv(self, k_val: Tensor, v_val: Tensor, bs: int):
186
+ # input_pos: int, k_val: [B, S, H, D]
187
+
188
+ self.k_cache[[bs], :, : k_val.shape[1]] = k_val.transpose(1, 2)
189
+ self.v_cache[[bs], :, : v_val.shape[1]] = v_val.transpose(1, 2)
190
+
191
+
192
+ class AttentionABC(ABC, nn.Module):
193
+ def __init__(self):
194
+ super().__init__()
195
+ self.n_head: int
196
+ self.hidden_dim: int
197
+ self.head_dim: int
198
+
199
+ # key, query, value projections for all heads, but in a batch
200
+ self.in_proj: nn.Linear
201
+ self.out_proj: nn.Linear
202
+
203
+ self.dropout = nn.Dropout(0.1)
204
+
205
+ self._register_load_state_dict_pre_hook(self.load_hook)
206
+
207
+ def load_hook(self, state_dict: dict, prefix, *args):
208
+ keys_to_modify = [key for key in state_dict if "in_proj_" in key]
209
+ for key in keys_to_modify:
210
+ new_key = key.replace("in_proj_", "in_proj.") # in_proj_ -> in_proj.
211
+ state_dict[new_key] = state_dict.pop(key)
212
+
213
+ @abstractmethod
214
+ def forward(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheABC, *args, **kwds) -> Tensor: ...
215
+
216
+ def prefill(self, x: Tensor, mask: Tensor, kv_cache: KVCacheABC) -> Tensor:
217
+ bsz = x.size(0)
218
+
219
+ outputs = []
220
+
221
+ for bs in range(bsz):
222
+ x_b = x[bs].unsqueeze(0)
223
+
224
+ q, k, v = self.in_proj.forward(x_b.unsqueeze(0)).chunk(3, dim=-1)
225
+
226
+ q = q.contiguous().view(1, -1, self.n_head, self.head_dim)
227
+ k = k.contiguous().view(1, -1, self.n_head, self.head_dim)
228
+ v = v.contiguous().view(1, -1, self.n_head, self.head_dim)
229
+
230
+ kv_cache.prefill_kv(k, v, bs)
231
+
232
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
233
+
234
+ attn_mask = mask[bs].unsqueeze(0).unsqueeze(0).expand(1, self.n_head, -1, -1)
235
+
236
+ attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
237
+
238
+ attn = self.dropout.forward(attn)
239
+
240
+ attn = attn.transpose(1, 2).contiguous().view(1, -1, self.hidden_dim)
241
+
242
+ output = self.out_proj.forward(attn)
243
+
244
+ outputs.append(output.squeeze(0))
245
+
246
+ return torch.nested.nested_tensor(outputs)
247
+
248
+
249
+ class FeedForward(nn.Module):
250
+ def __init__(self, dim: int, hidden_dim: int) -> None:
251
+ super().__init__()
252
+ self.linear1 = nn.Linear(dim, hidden_dim, bias=True)
253
+ self.linear2 = nn.Linear(hidden_dim, dim, bias=True)
254
+ self.dropout = nn.Dropout(0.1)
255
+
256
+ def forward(self, x: Tensor) -> Tensor:
257
+ return self.dropout.forward(self.linear2(self.dropout.forward(F.relu(self.linear1(x)))))
258
+
259
+
260
+ class TransformerBlockABC(ABC, nn.Module):
261
+ def __init__(self) -> None:
262
+ super().__init__()
263
+ self.hidden_dim: int
264
+ self.attention: AttentionABC
265
+ self.feed_forward: FeedForward
266
+ self.attention_norm: nn.LayerNorm
267
+ self.ffn_norm: nn.LayerNorm
268
+ self.dropout = nn.Dropout(0.1)
269
+
270
+ self._register_load_state_dict_pre_hook(self.load_hook)
271
+
272
+ def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
273
+ for key in list(state_dict.keys()):
274
+ new_key = (
275
+ key.replace("self_attn", "attention")
276
+ .replace("linear", "feed_forward.linear")
277
+ .replace("norm1", "attention_norm")
278
+ .replace("norm2", "ffn_norm")
279
+ )
280
+ state_dict[new_key] = state_dict.pop(key)
281
+
282
+ def forward(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheABC, *args, **kwds) -> Tensor:
283
+ h = self.attention_norm.forward(
284
+ x
285
+ + self.dropout.forward(
286
+ self.attention.forward(
287
+ x,
288
+ input_pos,
289
+ kv_cache,
290
+ *args,
291
+ **kwds,
292
+ )
293
+ )
294
+ )
295
+ out = self.ffn_norm.forward(h + self.feed_forward.forward(h))
296
+ return out
297
+
298
+ def prefill(self, x: Tensor, mask: Tensor, kv_cache: KVCacheABC) -> Tensor:
299
+ h = self.attention_norm.forward(
300
+ x
301
+ + self.dropout.forward(
302
+ self.attention.prefill(
303
+ x,
304
+ mask,
305
+ kv_cache,
306
+ )
307
+ )
308
+ )
309
+ out = self.ffn_norm.forward(h + self.feed_forward.forward(h))
310
+ return out
311
+
312
+
313
+ class TransformerDecoderABC(ABC, nn.Module):
314
+ def __init__(self) -> None:
315
+ super().__init__()
316
+
317
+ self.hidden_dim: int
318
+ self.n_head: int
319
+ self.head_dim: int
320
+ self.vocab_size: int
321
+ self.n_layer: int
322
+
323
+ self.layers: MutableSequence[TransformerBlockABC]
324
+
325
+ self.max_seq_length: int
326
+ self.max_batch_size: int
327
+
328
+ self.input_pos: Tensor
329
+ self.xy_pos: Tensor
330
+ self.xy_dec: Tensor
331
+
332
+ def forward(self, input_pos: Tensor, x: Tensor, kv_caches: MutableSequence[KVCacheABC], *args, **kwds):
333
+ for layer, kv_cache in zip(self.layers, kv_caches):
334
+ x = layer.forward(x, input_pos, kv_cache, *args, **kwds)
335
+ return x
336
+
337
+ def prefill(self, x: Tensor, mask: Tensor, kv_caches: MutableSequence[KVCacheABC]):
338
+ for layer, kv_cache in zip(self.layers, kv_caches):
339
+ x = layer.prefill(x, mask, kv_cache)
340
+ return x
341
+
342
+
343
+ class T2SDecoderABC(ABC, nn.Module):
344
+ def __init__(self) -> None:
345
+ super().__init__()
346
+
347
+ self.n_layer: int
348
+ self.hidden_dim: int
349
+ self.n_head: int
350
+
351
+ self.head_dim: int
352
+ self.embedding_dim: int
353
+ self.vocab_size: int
354
+ self.phoneme_vocab_size: int
355
+ self.p_dropout: float
356
+ self.max_seq_length: int
357
+ self.max_batch_size: int
358
+ self.EOS: int
359
+
360
+ self.bert_proj: nn.Linear
361
+ self.ar_text_embedding: TokenEmbedding
362
+ self.ar_text_position: SinePositionalEmbedding
363
+ self.ar_audio_embedding: TokenEmbedding
364
+ self.ar_audio_position: SinePositionalEmbedding
365
+ self.ar_predict_layer: nn.Linear
366
+ self.h: TransformerDecoderABC
367
+
368
+ self.kv_class: Type[KVCacheNHD] | Type[KVCacheHND]
369
+
370
+ self.GraphCache: CUDAGraphCacheABC | None
371
+
372
+ self._register_load_state_dict_pre_hook(self.load_hook)
373
+
374
+ def load_hook(self, state_dict, prefix, *args):
375
+ model_keys = [key for key in state_dict if key.startswith("model.")]
376
+ for key in model_keys:
377
+ new_key = key[len("model.") :]
378
+ state_dict[new_key] = state_dict.pop(key)
379
+
380
+ def init_cache(self, bsz: int = 0) -> MutableSequence[KVCacheABC]:
381
+ bsz = bsz or self.h.max_batch_size
382
+ assert bsz <= self.h.max_batch_size
383
+ seq_lens = self.h.max_seq_length
384
+ device = self.bert_proj.bias.device
385
+ dtype = self.bert_proj.bias.dtype
386
+ kvclass = self.kv_class
387
+ return nn.ModuleList(
388
+ [kvclass(bsz, seq_lens, self.n_head, self.head_dim) for _ in range(self.n_layer)],
389
+ ).to(device, dtype) # type: ignore
390
+
391
+ @abstractmethod
392
+ def embed(self, x: List[torch.Tensor], y: torch.Tensor, bert_features: List[Tensor]) -> Tensor: ...
393
+
394
+ def compile(self, *args, **kwds):
395
+ torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True
396
+ torch._inductor.config.coordinate_descent_tuning = True
397
+ torch._inductor.config.triton.unique_kernel_names = True
398
+ # Experimental features to reduce compilation times, will be on by default in future
399
+ torch._inductor.config.fx_graph_cache = True
400
+ torch._inductor.config.triton.cudagraph_trees = True
401
+ torch._inductor.config.triton.cudagraph_support_input_mutation = True
402
+ self.h.compile(fullgraph=True, mode="reduce-overhead")
403
+
404
+ def capture(self, input_pos: Tensor, x: Tensor, x_dec: Tensor, *args, **kwds) -> CUDAGraph:
405
+ assert torch.cuda.is_available()
406
+ s = torch.cuda.Stream()
407
+ s.wait_stream(torch.cuda.current_stream())
408
+
409
+ graph = torch.cuda.CUDAGraph()
410
+
411
+ with torch.cuda.stream(s): # type: ignore
412
+ for _ in range(5):
413
+ self.h.forward(input_pos, x, *args, **kwds)
414
+ torch.cuda.current_stream().wait_stream(s)
415
+
416
+ with torch.cuda.graph(graph):
417
+ x_dec.copy_(self.h.forward(input_pos, x, *args, **kwds))
418
+ torch.cuda.synchronize()
419
+
420
+ return graph
421
+
422
+ @abstractmethod
423
+ def pre_forward(self, session: Any) -> Tuple[List, Dict]: ...
424
+
425
+ @abstractmethod
426
+ def post_forward(self, idx: int, session: Any) -> None: ...
427
+
428
+
429
+ class CUDAGraphCacheABC(ABC):
430
+ def __init__(
431
+ self,
432
+ decoder: T2SDecoderABC,
433
+ device: torch.device = torch.device("cpu"),
434
+ dtype: torch.dtype = torch.float32,
435
+ ) -> None:
436
+ assert torch.cuda.is_available()
437
+
438
+ self.assigned: bool = False
439
+
440
+ self.decoder: T2SDecoderABC = decoder
441
+ self.kv_cache: MutableSequence[KVCacheABC] = decoder.init_cache(1)
442
+ self.xy_pos = torch.rand((1, 1, decoder.embedding_dim), device=device).to(dtype)
443
+ self.xy_dec = torch.rand((1, 1, decoder.embedding_dim), device=device).to(dtype)
444
+ self.input_pos = torch.tensor([10]).int().cuda()
445
+ self.graph: torch.cuda.CUDAGraph | None = None
446
+
447
+ self.id: int = random.randint(1, 2**32 - 1)
448
+
449
+ def assign_graph(self, session: Any):
450
+ if self.graph is None:
451
+ args, kwds = self.decoder.pre_forward(session)
452
+ graph = self.decoder.capture(
453
+ self.input_pos, self.xy_pos, self.xy_dec, kv_caches=self.kv_cache, *args, **kwds
454
+ )
455
+ self.graph = graph
456
+
457
+ if self.assigned is False:
458
+ self.get_cache_graph(session)
459
+ session.id = self.id
460
+ self.assigned = True
461
+ else:
462
+ self.capture_new_graph(session)
463
+
464
+ @abstractmethod
465
+ def release_graph(self, session: Any): ...
466
+
467
+ @abstractmethod
468
+ def get_cache_graph(self, session: Any):
469
+ pass
470
+
471
+ @abstractmethod
472
+ def capture_new_graph(self, session: Any):
473
+ pass
474
+
475
+
476
+ class TorchProfiler:
477
+ def __init__(self, debug: bool, log_dir: str = "./profiler") -> None:
478
+ self.debug = debug
479
+ self.log_dir = log_dir
480
+ self.__profiler: torch.profiler.profile
481
+
482
+ if self.debug and not os.path.exists(self.log_dir):
483
+ os.makedirs(self.log_dir)
484
+
485
+ self.tensorboard_handler = tensorboard_trace_handler(self.log_dir)
486
+
487
+ def profiler_callback(self, prof: torch.profiler.profile):
488
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
489
+ print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=30))
490
+ self.tensorboard_handler(prof)
491
+
492
+ @staticmethod
493
+ def three_step_schedule(step: int) -> ProfilerAction:
494
+ if step == 0:
495
+ return ProfilerAction.NONE
496
+ elif step == 1:
497
+ return ProfilerAction.RECORD
498
+ elif step == 2:
499
+ return ProfilerAction.RECORD_AND_SAVE
500
+ else:
501
+ return ProfilerAction.NONE
502
+
503
+ def start(self):
504
+ if not self.debug:
505
+ return
506
+ assert self.__profiler is not None
507
+ self.__profiler.step()
508
+
509
+ def end(self):
510
+ if not self.debug:
511
+ return
512
+ assert self.__profiler is not None
513
+ self.__profiler.step()
514
+
515
+ def profiler(self):
516
+ if self.debug:
517
+ activities_list = [torch.profiler.ProfilerActivity.CPU]
518
+ if torch.cuda.is_available():
519
+ activities_list.append(torch.profiler.ProfilerActivity.CUDA)
520
+
521
+ self.__profiler = torch.profiler.profile(
522
+ activities=activities_list,
523
+ record_shapes=True,
524
+ with_stack=True,
525
+ with_modules=True,
526
+ profile_memory=True,
527
+ schedule=self.three_step_schedule,
528
+ on_trace_ready=self.profiler_callback,
529
+ )
530
+ return self.__profiler
531
+ else:
532
+ return nullcontext()
533
+
534
+ def record(self, func_name: str):
535
+ if self.debug:
536
+ return torch.profiler.record_function(func_name)
537
+ else:
538
+ return nullcontext()
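
The Sampler above applies a repetition penalty, nucleus (top-p) filtering, top-k filtering, and exponential-noise argmax sampling in a single pass; a self-contained call with illustrative shapes:

import torch

from AR.models.t2s_model_abc import Sampler

batch, vocab = 2, 1025
sampler = Sampler(batch, vocab)

logits = torch.randn(batch, vocab)
history = torch.randint(0, vocab, (batch, 30))   # previously generated tokens

next_token = sampler.sample(
    logits,
    previous_tokens=history,
    temperature=1.0,
    top_k=15,
    top_p=1.0,
    repetition_penalty=1.35,
)
print(next_token.shape, next_token.dtype)   # torch.Size([2, 1]) torch.int32
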
AR/models/t2s_model_flash_attn.py ADDED
@@ -0,0 +1,408 @@
1
+ """
2
+ Modified From https://github.com/XXXXRT666/GPT-SoVITS
3
+ """
4
+
5
+ import gc
6
+ import os
7
+ import time
8
+ import traceback
9
+ from typing import Dict, List, Tuple
10
+
11
+ import flash_attn # type: ignore
12
+ import torch
13
+ import torch.nn as nn
14
+ from tqdm import tqdm
15
+
16
+ from AR.models.embedding import (
17
+ SinePositionalEmbeddingNested as SinePositionalEmbedding,
18
+ )
19
+ from AR.models.embedding import TokenEmbedding
20
+ from AR.models.structs import T2SRequest, T2SResult, T2SSession
21
+ from AR.models.t2s_model_abc import (
22
+ AttentionABC,
23
+ CUDAGraphCacheABC,
24
+ FeedForward,
25
+ KVCacheABC,
26
+ KVCacheNHD,
27
+ T2SDecoderABC,
28
+ TorchProfiler,
29
+ TransformerBlockABC,
30
+ TransformerDecoderABC,
31
+ )
32
+
33
+ Tensor = torch.Tensor
34
+
35
+
36
+ class Attention(AttentionABC):
37
+ def __init__(self, n_head: int, hidden_dim: int):
38
+ super().__init__()
39
+ self.n_head = n_head
40
+ self.hidden_dim = hidden_dim
41
+ assert hidden_dim % n_head == 0
42
+ self.head_dim = hidden_dim // n_head
43
+
44
+ self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
45
+ self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
46
+
47
+ def forward(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheABC, *args, **kwds) -> Tensor:
48
+ bsz, seqlen, _ = x.shape
49
+
50
+ q, k, v = self.in_proj.forward(x).chunk(3, dim=-1)
51
+
52
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
53
+ k = k.view(bsz, seqlen, self.n_head, self.head_dim)
54
+ v = v.view(bsz, seqlen, self.n_head, self.head_dim)
55
+
56
+ attn: Tensor = flash_attn.flash_attn_with_kvcache(
57
+ q, kv_cache.k_cache, kv_cache.v_cache, k, v, cache_seqlens=input_pos - 1
58
+ ) # type: ignore
59
+
60
+ attn = self.dropout.forward(attn)
61
+
62
+ attn = attn.view(bsz, seqlen, self.hidden_dim)
63
+
64
+ attn = self.out_proj.forward(attn)
65
+
66
+ return attn
67
+
68
+
69
+ class TransformerBlock(TransformerBlockABC):
70
+ def __init__(self, n_head, ffn_dim, hidden_dim) -> None:
71
+ super().__init__()
72
+ self.hidden_dim = hidden_dim
73
+ self.attention = Attention(n_head, hidden_dim)
74
+ self.feed_forward = FeedForward(hidden_dim, ffn_dim)
75
+ self.attention_norm = nn.LayerNorm([self.hidden_dim])
76
+ self.ffn_norm = nn.LayerNorm([self.hidden_dim])
77
+
78
+
79
+ class TransformerDecoder(TransformerDecoderABC):
80
+ def __init__(
81
+ self,
82
+ hidden_dim,
83
+ n_layer,
84
+ n_head,
85
+ ffn_dim,
86
+ vocab_size,
87
+ max_seq_length,
88
+ max_batch_size,
89
+ ) -> None:
90
+ super().__init__()
91
+
92
+ self.hidden_dim = hidden_dim
93
+ self.n_head = n_head
94
+ assert hidden_dim % n_head == 0
95
+
96
+ self.head_dim = hidden_dim // n_head
97
+ self.vocab_size = vocab_size
98
+
99
+ self.n_layer = n_layer
100
+
101
+ self.layers = nn.ModuleList( # type: ignore
102
+ TransformerBlock(n_head, ffn_dim, hidden_dim) for _ in range(n_layer)
103
+ )
104
+
105
+ self.max_seq_length: int = max_seq_length
106
+ self.max_batch_size: int = max_batch_size
107
+
108
+ self.setup_caches(self.max_batch_size, self.max_seq_length)
109
+
110
+ def setup_caches(self, max_batch_size=10, max_seq_length=2500):
111
+ self.max_seq_length = max_seq_length
112
+ self.max_batch_size = max_batch_size
113
+
114
+
115
+ class T2SDecoder(T2SDecoderABC):
116
+ def __init__(
117
+ self,
118
+ config,
119
+ *args,
120
+ norm_first=False,
121
+ max_seq_length=2500,
122
+ max_batch_size=10,
123
+ **kwds,
124
+ ) -> None:
125
+ assert torch.cuda.is_available()
126
+ super().__init__()
127
+
128
+ hidden_dim = config["model"]["hidden_dim"]
129
+ embedding_dim = config["model"]["embedding_dim"]
130
+ n_head = config["model"]["head"]
131
+ n_layer = config["model"]["n_layer"]
132
+ vocab_size = config["model"]["vocab_size"]
133
+ phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
134
+ p_dropout = config["model"]["dropout"]
135
+ EOS = config["model"]["EOS"]
136
+ ffn_dim = hidden_dim * 4
137
+ self.norm_first = norm_first
138
+
139
+ self.n_layer = n_layer
140
+ self.hidden_dim = hidden_dim
141
+ self.n_head = n_head
142
+ assert hidden_dim % n_head == 0
143
+
144
+ self.head_dim = hidden_dim // n_head
145
+ self.embedding_dim = embedding_dim
146
+ self.vocab_size = vocab_size
147
+ self.phoneme_vocab_size = phoneme_vocab_size
148
+ self.p_dropout = p_dropout
149
+ self.max_seq_length = max_seq_length
150
+ self.max_batch_size = max_batch_size
151
+ self.EOS = EOS
152
+ assert self.EOS == self.vocab_size - 1
153
+
154
+ self.bert_proj = nn.Linear(1024, self.embedding_dim)
155
+ self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
156
+ self.ar_text_position = SinePositionalEmbedding(
157
+ self.embedding_dim,
158
+ dropout=0.1,
159
+ scale=False,
160
+ alpha=True,
161
+ max_batch_size=max_batch_size,
162
+ max_seq_len=max_seq_length,
163
+ )
164
+ self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size, self.p_dropout)
165
+ self.ar_audio_position = SinePositionalEmbedding(
166
+ self.embedding_dim,
167
+ dropout=0.1,
168
+ scale=False,
169
+ alpha=True,
170
+ max_batch_size=max_batch_size,
171
+ max_seq_len=max_seq_length,
172
+ )
173
+ self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
174
+ self.h: TransformerDecoderABC = TransformerDecoder(
175
+ hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size
176
+ )
177
+
178
+ self.kv_class = KVCacheNHD
179
+ self._register_load_state_dict_pre_hook(self.load_hook)
180
+
181
+ def embed(
182
+ self,
183
+ x: List[torch.Tensor],
184
+ y: torch.Tensor,
185
+ bert_features: List[torch.Tensor],
186
+ ):
187
+ x_nested = torch.nested.nested_tensor(x)
188
+ assert x_nested.size(0) <= self.max_batch_size
189
+ bert_features_nested = torch.nested.nested_tensor(list(map(lambda x: x.transpose(0, 1), bert_features)))
190
+
191
+ x_emb = self.ar_text_embedding.forward(x_nested)
192
+ bert = self.bert_proj.forward(bert_features_nested)
193
+ x_emb = x_emb + bert
194
+ x_pos = self.ar_text_position.prefill(x_emb)
195
+
196
+ y_nested = torch.nested.nested_tensor(list(y.unbind(0)))
197
+ y_emb = self.ar_audio_embedding.forward(y_nested)
198
+ y_pos = self.ar_audio_position.prefill(y_emb)
199
+
200
+ xy_pos = torch.nested.nested_tensor([torch.cat([x_pos[i], y_pos[i]]) for i in range(len(x))])
201
+ return xy_pos
202
+
203
+ def post_forward(self, idx: int, session: T2SSession) -> None:
204
+ pass
205
+
206
+ def pre_forward(self, session: T2SSession) -> Tuple[List, Dict]:
207
+ return list(), dict()
208
+
209
+
210
+ class CUDAGraphCache(CUDAGraphCacheABC):
211
+ def __init__(
212
+ self,
213
+ decoder: T2SDecoderABC,
214
+ device: torch.device = torch.device("cpu"),
215
+ dtype: torch.dtype = torch.float32,
216
+ ) -> None:
217
+ super().__init__(decoder, device, dtype)
218
+
219
+ def release_graph(self, session: T2SSession):
220
+ if session.id != self.id:
221
+ self.assigned = False
222
+ else:
223
+ del session.graph, session.xy_pos_, session.xy_dec_, session.input_pos, session.kv_cache
224
+
225
+ def get_cache_graph(self, session: T2SSession):
226
+ assert self.graph
227
+ session.graph = self.graph
228
+
229
+ session.xy_pos_ = self.xy_pos
230
+ session.xy_dec_ = self.xy_dec
231
+ session.input_pos = self.input_pos.copy_(session.input_pos)
232
+
233
+ for cache, cache_ in zip(self.kv_cache, session.kv_cache):
234
+ cache.sync_cache(cache_)
235
+
236
+ def capture_new_graph(self, session: T2SSession):
237
+ session.xy_pos_ = self.xy_pos.clone()
238
+ session.xy_dec_ = self.xy_dec.clone()
239
+ session.input_pos = self.input_pos.clone().copy_(session.input_pos)
240
+
241
+ args, kwds = self.decoder.pre_forward(session)
242
+ graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, kv_caches=self.kv_cache, *args, **kwds)
243
+ session.graph = graph
244
+
245
+
246
+ class CUDAGraphRunner:
247
+ def __init__(
248
+ self,
249
+ decoder_model: T2SDecoderABC,
250
+ device: torch.device = torch.device("cpu"),
251
+ dtype: torch.dtype = torch.float32,
252
+ ) -> None:
253
+ assert device.type == "cuda"
254
+ self.device = device
255
+ self.dtype = dtype
256
+
257
+ self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
258
+
259
+ self.graphcache = CUDAGraphCache(decoder_model, device, dtype)
260
+
261
+ def _handle_request(self, request: T2SRequest):
262
+ with self.device:
263
+ decoder = self.decoder_model
264
+ session = T2SSession(decoder, request, device=self.device, dtype=self.dtype)
265
+
266
+ t1 = 0.0
267
+ infer_speed = 0.0
268
+
269
+ torch_profiler = TorchProfiler(request.debug)
270
+ with torch_profiler.profiler():
271
+ for idx in tqdm(range(1500)):
272
+ if idx == 0:
273
+ session.kv_cache = decoder.init_cache(session.bsz)
274
+ xy_dec = decoder.h.prefill(session.xy_pos, session.attn_mask_nested, session.kv_cache)
275
+ xy_dec = torch.stack([t[[-1]] for t in xy_dec.unbind()])
276
+ else:
277
+ if request.use_cuda_graph and session.graph is None and torch.cuda.is_available():
278
+ self.graphcache.assign_graph(session)
279
+
280
+ with torch_profiler.record("AR"):
281
+ if session.graph:
282
+ session.xy_pos_.copy_(session.xy_pos)
283
+ session.graph.replay()
284
+ xy_dec = session.xy_dec_.clone()
285
+ else:
286
+ args, kwds = decoder.pre_forward(session)
287
+ xy_dec = decoder.h.forward(
288
+ session.input_pos,
289
+ session.xy_pos,
290
+ session.kv_cache,
291
+ *args,
292
+ **kwds,
293
+ )
294
+
295
+ decoder.post_forward(idx, session)
296
+ logits = decoder.ar_predict_layer(xy_dec[:, -1])
297
+ session.input_pos.add_(1)
298
+
299
+ if idx == 0:
300
+ logits[:, -1] = float("-inf")
301
+
302
+ with torch_profiler.record("Sampling"):
303
+ samples = session.sampler.sample(
304
+ logits=logits,
305
+ previous_tokens=session.y,
306
+ top_k=request.top_k,
307
+ top_p=request.top_p,
308
+ repetition_penalty=request.repetition_penalty,
309
+ temperature=request.temperature,
310
+ )
311
+
312
+ session.y = torch.cat([session.y, samples], dim=1)
313
+
314
+ with torch_profiler.record("EOS"):
315
+ argmax_token = torch.argmax(logits, dim=-1)
316
+ sample_token = samples.squeeze(1)
317
+ EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS)
318
+
319
+ newly_done_mask = EOS_mask & (~session.completed)
320
+ newly_done_indices = newly_done_mask.nonzero()
321
+
322
+ if newly_done_indices.numel() > 0:
323
+ session.y_results[newly_done_indices[0]] = session.y[
324
+ newly_done_indices[0], session.y_len : -1
325
+ ].squeeze(0)
326
+ session.completed[newly_done_indices] = True
327
+
328
+ if torch.all(session.completed).item():
329
+ if session.y.size(1) == 0:
330
+ session.y = torch.cat([session.y, torch.zeros_like(samples)], dim=1)
331
+ tqdm.write("Bad Zero Prediction")
332
+ else:
333
+ tqdm.write(
334
+ f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> \n{[i.size(0) for i in session.y_results].__str__().strip('[]')}"
335
+ )
336
+ tqdm.write(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
337
+ infer_speed = (idx - 1) / (time.perf_counter() - t1)
338
+ break
339
+
340
+ if (
341
+ request.early_stop_num != -1
342
+ and (session.y.size(1) - session.y_len) > request.early_stop_num
343
+ ) or idx == 1499:
344
+ for i in range(session.bsz):
345
+ if not session.completed[i].item():
346
+ session.y_results[i] = session.y[i, session.y_len :]
347
+ session.completed[i] = True
348
+ break
349
+
350
+ with torch_profiler.record("NextPos"):
351
+ y_emb = decoder.ar_audio_embedding(session.y[:, -1:])
352
+ session.xy_pos = decoder.ar_audio_position.forward(session.input_pos - session.x_lens, y_emb)
353
+
354
+ if idx == 2:
355
+ torch_profiler.start()
356
+ t1 = time.perf_counter()
357
+
358
+ if idx == 51:
359
+ torch_profiler.end()
360
+
361
+ if idx % 100 == 0:
362
+ match session.device.type:
363
+ case "cuda":
364
+ torch.cuda.empty_cache()
365
+ case "mps":
366
+ torch.mps.empty_cache()
367
+ case "xpu":
368
+ torch.xpu.empty_cache()
369
+ case "mtia":
370
+ torch.mtia.empty_cache()
371
+
372
+ match session.device.type:
373
+ case "cuda":
374
+ torch.cuda.empty_cache()
375
+ case "mps":
376
+ torch.mps.empty_cache()
377
+ case "xpu":
378
+ torch.xpu.empty_cache()
379
+ case "mtia":
380
+ torch.mtia.empty_cache()
381
+ case "cpu":
382
+ gc.collect()
383
+
384
+ torch_profiler.end()
385
+ self.graphcache.release_graph(session)
386
+ return session.y_results[: request.valid_length], infer_speed
387
+
388
+ def generate(self, request: T2SRequest):
389
+ try:
390
+ result, infer_speed = self._handle_request(request)
391
+ t2s_result = T2SResult(result=result, infer_speed=infer_speed, status="Success")
392
+ except Exception as e:
393
+ t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
394
+ return t2s_result
395
+
396
+ @staticmethod
397
+ def load_decoder(weights_path: os.PathLike, implement: str = "flash_attn"):
398
+ print(f"Loading Text2Semantic Weights from {weights_path} with {implement.replace('_', ' ').title()} Implement")
399
+ module_path = f"AR.models.t2s_model_{implement.lower()}"
400
+ cls_name = "T2SDecoder"
401
+ mod = __import__(module_path, fromlist=[cls_name])
402
+ decoder_cls: T2SDecoderABC = getattr(mod, cls_name)
403
+ dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=False, mmap=True)
404
+ config = dict_s1["config"]
405
+ decoder: T2SDecoderABC = decoder_cls(config, max_batch_size=1)
406
+ state_dict = dict_s1["weight"]
407
+ decoder.load_state_dict(state_dict)
408
+ return decoder.eval()
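
An end-to-end sketch under stated assumptions: a CUDA device, the flash_attn package, and a stage-1 GPT checkpoint at a placeholder path; the input tensors are likewise placeholders for frontend outputs:

import torch

from AR.models.structs import T2SRequest
from AR.models.t2s_model_flash_attn import CUDAGraphRunner

decoder = CUDAGraphRunner.load_decoder("pretrained_models/s1.ckpt")  # placeholder path
runner = CUDAGraphRunner(decoder, device=torch.device("cuda"), dtype=torch.float16)

phonemes = torch.randint(0, 512, (48,))      # placeholder phoneme IDs
bert = torch.zeros(1024, 48)                 # placeholder BERT features
prompt = torch.randint(0, 1024, (1, 150))    # placeholder prompt semantic tokens

request = T2SRequest(
    x=[phonemes],
    x_lens=torch.tensor([48]),
    prompts=prompt,
    bert_feature=[bert],
    valid_length=1,
    use_cuda_graph=True,
)
result = runner.generate(request)
if result.status == "Success":
    semantic_tokens = result.result[0]       # generated semantic token IDs
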
configs/s1.yaml ADDED
@@ -0,0 +1,31 @@
1
+ train:
2
+ seed: 1234
3
+ epochs: 300
4
+ batch_size: 8
5
+ gradient_accumulation: 4
6
+ save_every_n_epoch: 1
7
+ precision: 16
8
+ gradient_clip: 1.0
9
+ optimizer:
10
+ lr: 0.01
11
+ lr_init: 0.00001
12
+ lr_end: 0.0001
13
+ warmup_steps: 2000
14
+ decay_steps: 40000
15
+ data:
16
+ max_eval_sample: 8
17
+ max_sec: 54
18
+ num_workers: 1
19
+ pad_val: 1024 # same with EOS in model
20
+ model:
21
+ vocab_size: 1025
22
+ phoneme_vocab_size: 512
23
+ embedding_dim: 512
24
+ hidden_dim: 512
25
+ head: 16
26
+ linear_units: 2048
27
+ n_layer: 12
28
+ dropout: 0
29
+ EOS: 1024
30
+ inference:
31
+ top_k: 5
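
The s1*.yaml files that follow share this layout and differ mainly in the model block (hidden_dim, n_layer, phoneme_vocab_size) and a few training knobs. A small loader sketch, assuming PyYAML is available:

import yaml

with open("configs/s1.yaml") as f:
    config = yaml.safe_load(f)

model_cfg = config["model"]
print(model_cfg["hidden_dim"], model_cfg["n_layer"], config["inference"]["top_k"])
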
configs/s1big.yaml ADDED
@@ -0,0 +1,31 @@
1
+ train:
2
+ seed: 1234
3
+ epochs: 300
4
+ batch_size: 8
5
+ gradient_accumulation: 4
6
+ save_every_n_epoch: 1
7
+ precision: 16-mixed
8
+ gradient_clip: 1.0
9
+ optimizer:
10
+ lr: 0.01
11
+ lr_init: 0.00001
12
+ lr_end: 0.0001
13
+ warmup_steps: 2000
14
+ decay_steps: 40000
15
+ data:
16
+ max_eval_sample: 8
17
+ max_sec: 54
18
+ num_workers: 1
19
+ pad_val: 1024 # same with EOS in model
20
+ model:
21
+ vocab_size: 1025
22
+ phoneme_vocab_size: 512
23
+ embedding_dim: 1024
24
+ hidden_dim: 1024
25
+ head: 16
26
+ linear_units: 2048
27
+ n_layer: 16
28
+ dropout: 0
29
+ EOS: 1024
30
+ inference:
31
+ top_k: 5
configs/s1big2.yaml ADDED
@@ -0,0 +1,31 @@
1
+ train:
2
+ seed: 1234
3
+ epochs: 300
4
+ batch_size: 12
5
+ gradient_accumulation: 4
6
+ save_every_n_epoch: 1
7
+ precision: 16-mixed
8
+ gradient_clip: 1.0
9
+ optimizer:
10
+ lr: 0.01
11
+ lr_init: 0.00001
12
+ lr_end: 0.0001
13
+ warmup_steps: 2000
14
+ decay_steps: 40000
15
+ data:
16
+ max_eval_sample: 8
17
+ max_sec: 54
18
+ num_workers: 1
19
+ pad_val: 1024 # same with EOS in model
20
+ model:
21
+ vocab_size: 1025
22
+ phoneme_vocab_size: 512
23
+ embedding_dim: 1024
24
+ hidden_dim: 1024
25
+ head: 16
26
+ linear_units: 2048
27
+ n_layer: 6
28
+ dropout: 0
29
+ EOS: 1024
30
+ inference:
31
+ top_k: 5
configs/s1longer-v2.yaml ADDED
@@ -0,0 +1,31 @@
1
+ train:
2
+ seed: 1234
3
+ epochs: 20
4
+ batch_size: 8
5
+ save_every_n_epoch: 1
6
+ precision: 16-mixed
7
+ gradient_clip: 1.0
8
+ optimizer:
9
+ lr: 0.01
10
+ lr_init: 0.00001
11
+ lr_end: 0.0001
12
+ warmup_steps: 2000
13
+ decay_steps: 40000
14
+ data:
15
+ max_eval_sample: 8
16
+ max_sec: 54
17
+ num_workers: 4
18
+ pad_val: 1024 # same with EOS in model
19
+ model:
20
+ vocab_size: 1025
21
+ phoneme_vocab_size: 732
22
+ embedding_dim: 512
23
+ hidden_dim: 512
24
+ head: 16
25
+ linear_units: 2048
26
+ n_layer: 24
27
+ dropout: 0
28
+ EOS: 1024
29
+ random_bert: 0
30
+ inference:
31
+ top_k: 15
configs/s1longer.yaml ADDED
@@ -0,0 +1,31 @@
1
+ train:
2
+ seed: 1234
3
+ epochs: 20
4
+ batch_size: 8
5
+ save_every_n_epoch: 1
6
+ precision: 16-mixed
7
+ gradient_clip: 1.0
8
+ optimizer:
9
+ lr: 0.01
10
+ lr_init: 0.00001
11
+ lr_end: 0.0001
12
+ warmup_steps: 2000
13
+ decay_steps: 40000
14
+ data:
15
+ max_eval_sample: 8
16
+ max_sec: 54
17
+ num_workers: 4
18
+ pad_val: 1024 # same with EOS in model
19
+ model:
20
+ vocab_size: 1025
21
+ phoneme_vocab_size: 512
22
+ embedding_dim: 512
23
+ hidden_dim: 512
24
+ head: 16
25
+ linear_units: 2048
26
+ n_layer: 24
27
+ dropout: 0
28
+ EOS: 1024
29
+ random_bert: 0
30
+ inference:
31
+ top_k: 5
configs/s1mq.yaml ADDED
@@ -0,0 +1,77 @@
1
+ train:
2
+ seed: 1234
3
+ epochs: 100
4
+ batch_size: 6
5
+ gradient_accumulation: 4
6
+ save_every_n_epoch: 1
7
+ precision: 32
8
+ gradient_clip: 1.0
9
+ optimizer:
10
+ lr: 0.01
11
+ lr_init: 0.00001
12
+ lr_end: 0.0001
13
+ warmup_steps: 2000
14
+ decay_steps: 40000
15
+ data:
16
+ max_eval_sample: 8
17
+ max_sec: 40
18
+ num_workers: 1
19
+ pad_val: 1024 # same with EOS in model
20
+ model:
21
+ saving_path: "ckpt/"
22
+ resume_checkpoint: null
23
+ vocoder_config_path: "quantizer/new_ckpt/config.json"
24
+ vocoder_ckpt_path: "quantizer/new_ckpt/g_00600000"
25
+ datadir: "/home/liweiche/GigaSpeech/wavs"
26
+ metapath: "/home/liweiche/GigaSpeech/train2.json"
27
+ val_metapath: "/home/liweiche/GigaSpeech/dev2.json"
28
+ sampledir: "logs/"
29
+ pretrained_path: null
30
+ lr: 0.0001
31
+ batch_size: 200.0
32
+ train_bucket_size: 8192
33
+ training_step: 800000
34
+ optim_flat_percent: 0.0
35
+ warmup_step: 50
36
+ adam_beta1: 0.9
37
+ adam_beta2: 0.98
38
+ ffd_size: 3072
39
+ hidden_size: 768
40
+ enc_nlayers: 6
41
+ dec_nlayers: 6
42
+ nheads: 12
43
+ ar_layer: 4
44
+ ar_ffd_size: 1024
45
+ ar_hidden_size: 256
46
+ ar_nheads: 4
47
+ aligner_softmax_temp: 1.0
48
+ layer_norm_eps: 0.00001
49
+ speaker_embed_dropout: 0.05
50
+ label_smoothing: 0.0
51
+ val_check_interval: 5000
52
+ check_val_every_n_epoch: 1
53
+ precision: "fp16"
54
+ nworkers: 16
55
+ distributed: true
56
+ accelerator: "ddp"
57
+ version: null
58
+ accumulate_grad_batches: 1
59
+ use_repetition_token: true
60
+ use_repetition_gating: false
61
+ repetition_penalty: 1.0
62
+ sampling_temperature: 1.0
63
+ top_k: -1
64
+ min_top_k: 3
65
+ top_p: 0.8
66
+ sample_num: 4
67
+ length_penalty_max_length: 15000
68
+ length_penalty_max_prob: 0.95
69
+ max_input_length: 2048
70
+ max_output_length: 2000
71
+ sample_rate: 16000
72
+ n_codes: 1024
73
+ n_cluster_groups: 1
74
+ phone_context_window: 4
75
+ phoneset_size: 1000
76
+ inference:
77
+ top_k: 5
configs/s2.json ADDED
@@ -0,0 +1,90 @@
1
+ {
2
+ "train": {
3
+ "log_interval": 100,
4
+ "eval_interval": 500,
5
+ "seed": 1234,
6
+ "epochs": 100,
7
+ "learning_rate": 0.0001,
8
+ "betas": [
9
+ 0.8,
10
+ 0.99
11
+ ],
12
+ "eps": 1e-09,
13
+ "batch_size": 32,
14
+ "fp16_run": true,
15
+ "lr_decay": 0.999875,
16
+ "segment_size": 20480,
17
+ "init_lr_ratio": 1,
18
+ "warmup_epochs": 0,
19
+ "c_mel": 45,
20
+ "c_kl": 1.0,
21
+ "text_low_lr_rate": 0.4
22
+ },
23
+ "data": {
24
+ "max_wav_value": 32768.0,
25
+ "sampling_rate": 32000,
26
+ "filter_length": 2048,
27
+ "hop_length": 640,
28
+ "win_length": 2048,
29
+ "n_mel_channels": 128,
30
+ "mel_fmin": 0.0,
31
+ "mel_fmax": null,
32
+ "add_blank": true,
33
+ "n_speakers": 300,
34
+ "cleaned_text": true
35
+ },
36
+ "model": {
37
+ "inter_channels": 192,
38
+ "hidden_channels": 192,
39
+ "filter_channels": 768,
40
+ "n_heads": 2,
41
+ "n_layers": 6,
42
+ "kernel_size": 3,
43
+ "p_dropout": 0.1,
44
+ "resblock": "1",
45
+ "resblock_kernel_sizes": [
46
+ 3,
47
+ 7,
48
+ 11
49
+ ],
50
+ "resblock_dilation_sizes": [
51
+ [
52
+ 1,
53
+ 3,
54
+ 5
55
+ ],
56
+ [
57
+ 1,
58
+ 3,
59
+ 5
60
+ ],
61
+ [
62
+ 1,
63
+ 3,
64
+ 5
65
+ ]
66
+ ],
67
+ "upsample_rates": [
68
+ 10,
69
+ 8,
70
+ 2,
71
+ 2,
72
+ 2
73
+ ],
74
+ "upsample_initial_channel": 512,
75
+ "upsample_kernel_sizes": [
76
+ 16,
77
+ 16,
78
+ 8,
79
+ 2,
80
+ 2
81
+ ],
82
+ "n_layers_q": 3,
83
+ "use_spectral_norm": false,
84
+ "gin_channels": 512,
85
+ "semantic_frame_rate": "25hz",
86
+ "freeze_quantizer": true
87
+ },
88
+ "s2_ckpt_dir": "logs/s2/big2k1",
89
+ "content_module": "cnhubert"
90
+ }
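
A quick sanity check of the audio frontend settings in s2.json, which appears to configure the stage-2 (SoVITS) synthesizer; the data block determines the feature frame rate:

import json

with open("configs/s2.json") as f:
    hps = json.load(f)

sr = hps["data"]["sampling_rate"]   # 32000
hop = hps["data"]["hop_length"]     # 640
print(f"{sr / hop:.0f} feature frames per second")   # 50
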
configs/train.yaml ADDED
@@ -0,0 +1,32 @@
1
+ gpu:
2
+ n_card: 1
3
+ n_process_per_card: 2
4
+ io:
5
+ text_path: D:\RVC1006\GPT-SoVITS\GPT_SoVITS
6
+ save_every_n_epoch: 1
7
+ precision: 16-mixed
8
+ gradient_clip: 1.0
9
+ optimizer:
10
+ lr: 0.01
11
+ lr_init: 0.00001
12
+ lr_end: 0.0001
13
+ warmup_steps: 2000
14
+ decay_steps: 40000
15
+ data:
16
+ max_eval_sample: 8
17
+ max_sec: 54
18
+ num_workers: 1
19
+ pad_val: 1024 # same with EOS in model
20
+ model:
21
+ vocab_size: 1025
22
+ phoneme_vocab_size: 512
23
+ embedding_dim: 512
24
+ hidden_dim: 512
25
+ head: 16
26
+ linear_units: 2048
27
+ n_layer: 24
28
+ dropout: 0
29
+ EOS: 1024
30
+ random_bert: 0
31
+ inference:
32
+ top_k: 5
download.py ADDED
@@ -0,0 +1,5 @@
1
+ import os, sys
2
+ now_dir = os.getcwd()
3
+ sys.path.insert(0, now_dir)
4
+ from text.g2pw import G2PWPinyin  # absolute import; a relative import fails when this file runs as a script
5
+ g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",v_to_u=False, neutral_tone_with_five=True)
eres2net/ERes2Net.py ADDED
@@ -0,0 +1,260 @@
1
+ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ """
5
+ Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
6
+ ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
7
+ The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
8
+ The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
9
+ """
10
+
11
+
12
+ import torch
13
+ import math
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ import pooling_layers as pooling_layers
17
+ from fusion import AFF
18
+
19
+ class ReLU(nn.Hardtanh):
20
+
21
+ def __init__(self, inplace=False):
22
+ super(ReLU, self).__init__(0, 20, inplace)
23
+
24
+ def __repr__(self):
25
+ inplace_str = 'inplace' if self.inplace else ''
26
+ return self.__class__.__name__ + ' (' \
27
+ + inplace_str + ')'
28
+
29
+
30
+ class BasicBlockERes2Net(nn.Module):
31
+ expansion = 2
32
+
33
+ def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
34
+ super(BasicBlockERes2Net, self).__init__()
35
+ width = int(math.floor(planes*(baseWidth/64.0)))
36
+ self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
37
+ self.bn1 = nn.BatchNorm2d(width*scale)
38
+ self.nums = scale
39
+
40
+ convs=[]
41
+ bns=[]
42
+ for i in range(self.nums):
43
+ convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
44
+ bns.append(nn.BatchNorm2d(width))
45
+ self.convs = nn.ModuleList(convs)
46
+ self.bns = nn.ModuleList(bns)
47
+ self.relu = ReLU(inplace=True)
48
+
49
+ self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
50
+ self.bn3 = nn.BatchNorm2d(planes*self.expansion)
51
+ self.shortcut = nn.Sequential()
52
+ if stride != 1 or in_planes != self.expansion * planes:
53
+ self.shortcut = nn.Sequential(
54
+ nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
55
+ stride=stride, bias=False),
56
+ nn.BatchNorm2d(self.expansion * planes))
57
+ self.stride = stride
58
+ self.width = width
59
+ self.scale = scale
60
+
61
+ def forward(self, x):
62
+ residual = x
63
+
64
+ out = self.conv1(x)
65
+ out = self.bn1(out)
66
+ out = self.relu(out)
67
+ spx = torch.split(out,self.width,1)
68
+ for i in range(self.nums):
69
+ if i==0:
70
+ sp = spx[i]
71
+ else:
72
+ sp = sp + spx[i]
73
+ sp = self.convs[i](sp)
74
+ sp = self.relu(self.bns[i](sp))
75
+ if i==0:
76
+ out = sp
77
+ else:
78
+ out = torch.cat((out,sp),1)
79
+
80
+ out = self.conv3(out)
81
+ out = self.bn3(out)
82
+
83
+ residual = self.shortcut(x)
84
+ out += residual
85
+ out = self.relu(out)
86
+
87
+ return out
88
+
89
+ class BasicBlockERes2Net_diff_AFF(nn.Module):
90
+ expansion = 2
91
+
92
+ def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
93
+ super(BasicBlockERes2Net_diff_AFF, self).__init__()
94
+ width = int(math.floor(planes*(baseWidth/64.0)))
95
+ self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
96
+ self.bn1 = nn.BatchNorm2d(width*scale)
97
+ self.nums = scale
98
+
99
+ convs=[]
100
+ fuse_models=[]
101
+ bns=[]
102
+ for i in range(self.nums):
103
+ convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
104
+ bns.append(nn.BatchNorm2d(width))
105
+ for j in range(self.nums - 1):
106
+ fuse_models.append(AFF(channels=width))
107
+
108
+ self.convs = nn.ModuleList(convs)
109
+ self.bns = nn.ModuleList(bns)
110
+ self.fuse_models = nn.ModuleList(fuse_models)
111
+ self.relu = ReLU(inplace=True)
112
+
113
+ self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
114
+ self.bn3 = nn.BatchNorm2d(planes*self.expansion)
115
+ self.shortcut = nn.Sequential()
116
+ if stride != 1 or in_planes != self.expansion * planes:
117
+ self.shortcut = nn.Sequential(
118
+ nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
119
+ stride=stride, bias=False),
120
+ nn.BatchNorm2d(self.expansion * planes))
121
+ self.stride = stride
122
+ self.width = width
123
+ self.scale = scale
124
+
125
+ def forward(self, x):
126
+ residual = x
127
+
128
+ out = self.conv1(x)
129
+ out = self.bn1(out)
130
+ out = self.relu(out)
131
+ spx = torch.split(out,self.width,1)
132
+ for i in range(self.nums):
133
+ if i==0:
134
+ sp = spx[i]
135
+ else:
136
+ sp = self.fuse_models[i-1](sp, spx[i])
137
+
138
+ sp = self.convs[i](sp)
139
+ sp = self.relu(self.bns[i](sp))
140
+ if i==0:
141
+ out = sp
142
+ else:
143
+ out = torch.cat((out,sp),1)
144
+
145
+ out = self.conv3(out)
146
+ out = self.bn3(out)
147
+
148
+ residual = self.shortcut(x)
149
+ out += residual
150
+ out = self.relu(out)
151
+
152
+ return out
153
+
154
+ class ERes2Net(nn.Module):
155
+ def __init__(self,
156
+ block=BasicBlockERes2Net,
157
+ block_fuse=BasicBlockERes2Net_diff_AFF,
158
+ num_blocks=[3, 4, 6, 3],
159
+ m_channels=32,
160
+ feat_dim=80,
161
+ embedding_size=192,
162
+ pooling_func='TSTP',
163
+ two_emb_layer=False):
164
+ super(ERes2Net, self).__init__()
165
+ self.in_planes = m_channels
166
+ self.feat_dim = feat_dim
167
+ self.embedding_size = embedding_size
168
+ self.stats_dim = int(feat_dim / 8) * m_channels * 8
169
+ self.two_emb_layer = two_emb_layer
170
+
171
+ self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
172
+ self.bn1 = nn.BatchNorm2d(m_channels)
173
+ self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
174
+ self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
175
+ self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
176
+ self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
177
+
178
+ # Downsampling module for each layer
179
+ self.layer1_downsample = nn.Conv2d(m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, bias=False)
180
+ self.layer2_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False)
181
+ self.layer3_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False)
182
+
183
+ # Bottom-up fusion module
184
+ self.fuse_mode12 = AFF(channels=m_channels * 4)
185
+ self.fuse_mode123 = AFF(channels=m_channels * 8)
186
+ self.fuse_mode1234 = AFF(channels=m_channels * 16)
187
+
188
+ self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
189
+ self.pool = getattr(pooling_layers, pooling_func)(
190
+ in_dim=self.stats_dim * block.expansion)
191
+ self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
192
+ embedding_size)
193
+ if self.two_emb_layer:
194
+ self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
195
+ self.seg_2 = nn.Linear(embedding_size, embedding_size)
196
+ else:
197
+ self.seg_bn_1 = nn.Identity()
198
+ self.seg_2 = nn.Identity()
199
+
200
+ def _make_layer(self, block, planes, num_blocks, stride):
201
+ strides = [stride] + [1] * (num_blocks - 1)
202
+ layers = []
203
+ for stride in strides:
204
+ layers.append(block(self.in_planes, planes, stride))
205
+ self.in_planes = planes * block.expansion
206
+ return nn.Sequential(*layers)
207
+
208
+ def forward(self, x):
209
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
210
+ x = x.unsqueeze_(1)
211
+ out = F.relu(self.bn1(self.conv1(x)))
212
+ out1 = self.layer1(out)
213
+ out2 = self.layer2(out1)
214
+ out1_downsample = self.layer1_downsample(out1)
215
+ fuse_out12 = self.fuse_mode12(out2, out1_downsample)
216
+ out3 = self.layer3(out2)
217
+ fuse_out12_downsample = self.layer2_downsample(fuse_out12)
218
+ fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
219
+ out4 = self.layer4(out3)
220
+ fuse_out123_downsample = self.layer3_downsample(fuse_out123)
221
+ fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
222
+ stats = self.pool(fuse_out1234)
223
+
224
+ embed_a = self.seg_1(stats)
225
+ if self.two_emb_layer:
226
+ out = F.relu(embed_a)
227
+ out = self.seg_bn_1(out)
228
+ embed_b = self.seg_2(out)
229
+ return embed_b
230
+ else:
231
+ return embed_a
232
+
233
+ def forward3(self, x):
234
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
235
+ x = x.unsqueeze_(1)
236
+ out = F.relu(self.bn1(self.conv1(x)))
237
+ out1 = self.layer1(out)
238
+ out2 = self.layer2(out1)
239
+ out1_downsample = self.layer1_downsample(out1)
240
+ fuse_out12 = self.fuse_mode12(out2, out1_downsample)
241
+ out3 = self.layer3(out2)
242
+ fuse_out12_downsample = self.layer2_downsample(fuse_out12)
243
+ fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
244
+ out4 = self.layer4(out3)
245
+ fuse_out123_downsample = self.layer3_downsample(fuse_out123)
246
+ fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1,end_dim=2).mean(-1)
247
+ return fuse_out1234
248
+
249
+
250
+ if __name__ == '__main__':
251
+
252
+ x = torch.zeros(10, 300, 80)
253
+ model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func='TSTP')
254
+ model.eval()
255
+ out = model(x)
256
+ print(out.shape) # torch.Size([10, 192])
257
+
258
+ num_params = sum(param.numel() for param in model.parameters())
259
+ print("{} M".format(num_params / 1e6)) # 6.61M
260
+
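A minimal usage sketch for the ERes2Net defined above, assuming the `eres2net/` directory is on `sys.path` so the flat imports used by these files (`pooling_layers`, `fusion`) resolve; the random tensor stands in for real 80-dim fbank features and the call mirrors the `__main__` block above.

    import sys, torch
    sys.path.append("eres2net")                       # assumption: run from the repo root
    from ERes2Net import ERes2Net

    model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func="TSTP")
    model.eval()
    feats = torch.randn(1, 300, 80)                   # (B, T, F) fbank features, placeholder values
    with torch.no_grad():
        emb = model(feats)                            # (1, 192) utterance-level speaker embedding
    print(emb.shape)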
eres2net/ERes2NetV2.py ADDED
@@ -0,0 +1,292 @@
1
+ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ """
5
+ To further improve the short-duration feature extraction capability of ERes2Net, we expand the channel dimension
6
+ within each stage. However, this modification also increases the number of model parameters and computational complexity.
7
+ To alleviate this problem, we propose an improved ERes2NetV2 by pruning redundant structures, ultimately reducing
8
+ both the model parameters and its computational cost.
9
+ """
10
+
11
+
12
+
13
+ import torch
14
+ import math
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import pooling_layers as pooling_layers
18
+ from fusion import AFF
19
+
20
+ class ReLU(nn.Hardtanh):
21
+
22
+ def __init__(self, inplace=False):
23
+ super(ReLU, self).__init__(0, 20, inplace)
24
+
25
+ def __repr__(self):
26
+ inplace_str = 'inplace' if self.inplace else ''
27
+ return self.__class__.__name__ + ' (' \
28
+ + inplace_str + ')'
29
+
30
+
31
+ class BasicBlockERes2NetV2(nn.Module):
32
+
33
+ def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
34
+ super(BasicBlockERes2NetV2, self).__init__()
35
+ width = int(math.floor(planes*(baseWidth/64.0)))
36
+ self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
37
+ self.bn1 = nn.BatchNorm2d(width*scale)
38
+ self.nums = scale
39
+ self.expansion = expansion
40
+
41
+ convs=[]
42
+ bns=[]
43
+ for i in range(self.nums):
44
+ convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
45
+ bns.append(nn.BatchNorm2d(width))
46
+ self.convs = nn.ModuleList(convs)
47
+ self.bns = nn.ModuleList(bns)
48
+ self.relu = ReLU(inplace=True)
49
+
50
+ self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
51
+ self.bn3 = nn.BatchNorm2d(planes*self.expansion)
52
+ self.shortcut = nn.Sequential()
53
+ if stride != 1 or in_planes != self.expansion * planes:
54
+ self.shortcut = nn.Sequential(
55
+ nn.Conv2d(in_planes,
56
+ self.expansion * planes,
57
+ kernel_size=1,
58
+ stride=stride,
59
+ bias=False),
60
+ nn.BatchNorm2d(self.expansion * planes))
61
+ self.stride = stride
62
+ self.width = width
63
+ self.scale = scale
64
+
65
+ def forward(self, x):
66
+ residual = x
67
+
68
+ out = self.conv1(x)
69
+ out = self.bn1(out)
70
+ out = self.relu(out)
71
+ spx = torch.split(out,self.width,1)
72
+ for i in range(self.nums):
73
+ if i==0:
74
+ sp = spx[i]
75
+ else:
76
+ sp = sp + spx[i]
77
+ sp = self.convs[i](sp)
78
+ sp = self.relu(self.bns[i](sp))
79
+ if i==0:
80
+ out = sp
81
+ else:
82
+ out = torch.cat((out,sp),1)
83
+
84
+ out = self.conv3(out)
85
+ out = self.bn3(out)
86
+
87
+ residual = self.shortcut(x)
88
+ out += residual
89
+ out = self.relu(out)
90
+
91
+ return out
92
+
93
+ class BasicBlockERes2NetV2AFF(nn.Module):
94
+
95
+ def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
96
+ super(BasicBlockERes2NetV2AFF, self).__init__()
97
+ width = int(math.floor(planes*(baseWidth/64.0)))
98
+ self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
99
+ self.bn1 = nn.BatchNorm2d(width*scale)
100
+ self.nums = scale
101
+ self.expansion = expansion
102
+
103
+ convs=[]
104
+ fuse_models=[]
105
+ bns=[]
106
+ for i in range(self.nums):
107
+ convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
108
+ bns.append(nn.BatchNorm2d(width))
109
+ for j in range(self.nums - 1):
110
+ fuse_models.append(AFF(channels=width, r=4))
111
+
112
+ self.convs = nn.ModuleList(convs)
113
+ self.bns = nn.ModuleList(bns)
114
+ self.fuse_models = nn.ModuleList(fuse_models)
115
+ self.relu = ReLU(inplace=True)
116
+
117
+ self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
118
+ self.bn3 = nn.BatchNorm2d(planes*self.expansion)
119
+ self.shortcut = nn.Sequential()
120
+ if stride != 1 or in_planes != self.expansion * planes:
121
+ self.shortcut = nn.Sequential(
122
+ nn.Conv2d(in_planes,
123
+ self.expansion * planes,
124
+ kernel_size=1,
125
+ stride=stride,
126
+ bias=False),
127
+ nn.BatchNorm2d(self.expansion * planes))
128
+ self.stride = stride
129
+ self.width = width
130
+ self.scale = scale
131
+
132
+ def forward(self, x):
133
+ residual = x
134
+
135
+ out = self.conv1(x)
136
+ out = self.bn1(out)
137
+ out = self.relu(out)
138
+ spx = torch.split(out,self.width,1)
139
+ for i in range(self.nums):
140
+ if i==0:
141
+ sp = spx[i]
142
+ else:
143
+ sp = self.fuse_models[i-1](sp, spx[i])
144
+
145
+ sp = self.convs[i](sp)
146
+ sp = self.relu(self.bns[i](sp))
147
+ if i==0:
148
+ out = sp
149
+ else:
150
+ out = torch.cat((out,sp),1)
151
+
152
+ out = self.conv3(out)
153
+ out = self.bn3(out)
154
+
155
+ residual = self.shortcut(x)
156
+ out += residual
157
+ out = self.relu(out)
158
+
159
+ return out
160
+
161
+ class ERes2NetV2(nn.Module):
162
+ def __init__(self,
163
+ block=BasicBlockERes2NetV2,
164
+ block_fuse=BasicBlockERes2NetV2AFF,
165
+ num_blocks=[3, 4, 6, 3],
166
+ m_channels=64,
167
+ feat_dim=80,
168
+ embedding_size=192,
169
+ baseWidth=26,
170
+ scale=2,
171
+ expansion=2,
172
+ pooling_func='TSTP',
173
+ two_emb_layer=False):
174
+ super(ERes2NetV2, self).__init__()
175
+ self.in_planes = m_channels
176
+ self.feat_dim = feat_dim
177
+ self.embedding_size = embedding_size
178
+ self.stats_dim = int(feat_dim / 8) * m_channels * 8
179
+ self.two_emb_layer = two_emb_layer
180
+ self.baseWidth = baseWidth
181
+ self.scale = scale
182
+ self.expansion = expansion
183
+
184
+ self.conv1 = nn.Conv2d(1,
185
+ m_channels,
186
+ kernel_size=3,
187
+ stride=1,
188
+ padding=1,
189
+ bias=False)
190
+ self.bn1 = nn.BatchNorm2d(m_channels)
191
+ self.layer1 = self._make_layer(block,
192
+ m_channels,
193
+ num_blocks[0],
194
+ stride=1)
195
+ self.layer2 = self._make_layer(block,
196
+ m_channels * 2,
197
+ num_blocks[1],
198
+ stride=2)
199
+ self.layer3 = self._make_layer(block_fuse,
200
+ m_channels * 4,
201
+ num_blocks[2],
202
+ stride=2)
203
+ self.layer4 = self._make_layer(block_fuse,
204
+ m_channels * 8,
205
+ num_blocks[3],
206
+ stride=2)
207
+
208
+ # Downsampling module
209
+ self.layer3_ds = nn.Conv2d(m_channels * 4 * self.expansion, m_channels * 8 * self.expansion, kernel_size=3, \
210
+ padding=1, stride=2, bias=False)
211
+
212
+ # Bottom-up fusion module
213
+ self.fuse34 = AFF(channels=m_channels * 8 * self.expansion, r=4)
214
+
215
+ self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
216
+ self.pool = getattr(pooling_layers, pooling_func)(
217
+ in_dim=self.stats_dim * self.expansion)
218
+ self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats,
219
+ embedding_size)
220
+ if self.two_emb_layer:
221
+ self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
222
+ self.seg_2 = nn.Linear(embedding_size, embedding_size)
223
+ else:
224
+ self.seg_bn_1 = nn.Identity()
225
+ self.seg_2 = nn.Identity()
226
+
227
+ def _make_layer(self, block, planes, num_blocks, stride):
228
+ strides = [stride] + [1] * (num_blocks - 1)
229
+ layers = []
230
+ for stride in strides:
231
+ layers.append(block(self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion))
232
+ self.in_planes = planes * self.expansion
233
+ return nn.Sequential(*layers)
234
+
235
+ def forward(self, x):
236
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
237
+ x = x.unsqueeze_(1)
238
+ out = F.relu(self.bn1(self.conv1(x)))
239
+ out1 = self.layer1(out)
240
+ out2 = self.layer2(out1)
241
+ out3 = self.layer3(out2)
242
+ out4 = self.layer4(out3)
243
+ out3_ds = self.layer3_ds(out3)
244
+ fuse_out34 = self.fuse34(out4, out3_ds)
245
+ stats = self.pool(fuse_out34)
246
+
247
+ embed_a = self.seg_1(stats)
248
+ if self.two_emb_layer:
249
+ out = F.relu(embed_a)
250
+ out = self.seg_bn_1(out)
251
+ embed_b = self.seg_2(out)
252
+ return embed_b
253
+ else:
254
+ return embed_a
255
+
256
+ def forward3(self, x):
257
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
258
+ x = x.unsqueeze_(1)
259
+ out = F.relu(self.bn1(self.conv1(x)))
260
+ out1 = self.layer1(out)
261
+ out2 = self.layer2(out1)
262
+ out3 = self.layer3(out2)
263
+ out4 = self.layer4(out3)
264
+ out3_ds = self.layer3_ds(out3)
265
+ fuse_out34 = self.fuse34(out4, out3_ds)
266
+ # debug: fuse_out34.shape observed as torch.Size([16, 2048, 10, 72])
267
+ return fuse_out34.flatten(start_dim=1,end_dim=2).mean(-1)
268
+ # stats = self.pool(fuse_out34)
269
+ #
270
+ # embed_a = self.seg_1(stats)
271
+ # if self.two_emb_layer:
272
+ # out = F.relu(embed_a)
273
+ # out = self.seg_bn_1(out)
274
+ # embed_b = self.seg_2(out)
275
+ # return embed_b
276
+ # else:
277
+ # return embed_a
278
+
279
+ if __name__ == '__main__':
280
+
281
+ x = torch.randn(1, 300, 80)
282
+ model = ERes2NetV2(feat_dim=80, embedding_size=192, m_channels=64, baseWidth=26, scale=2, expansion=2)
283
+ model.eval()
284
+ y = model(x)
285
+ print(y.size())
286
+ macs, num_params = profile(model, inputs=(x,))  # NOTE: `profile` is not imported in this file; it typically comes from thop (`from thop import profile`)
287
+ print("Params: {} M".format(num_params / 1e6)) # 17.86 M
288
+ print("MACs: {} G".format(macs / 1e9)) # 12.69 G
289
+
290
+
291
+
292
+
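A shape-check sketch for ERes2NetV2, mirroring its `__main__` block above; the thop-based MACs/params profiling is left out here because that import is not shown in this file.

    import sys, torch
    sys.path.append("eres2net")
    from ERes2NetV2 import ERes2NetV2

    model = ERes2NetV2(feat_dim=80, embedding_size=192, m_channels=64,
                       baseWidth=26, scale=2, expansion=2)
    model.eval()
    x = torch.randn(1, 300, 80)                       # (B, T, F)
    with torch.no_grad():
        print(model(x).shape)                         # torch.Size([1, 192]) speaker embedding
        print(model.forward3(x).shape)                # (1, C*F'): fused map flattened over channel/freq, averaged over time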
eres2net/ERes2Net_huge.py ADDED
@@ -0,0 +1,286 @@
1
+ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ """ Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
5
+ ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
6
+ The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
7
+ The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
8
+ ERes2Net-huge is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
9
+ recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance.
10
+ """
11
+ import pdb
12
+
13
+ import torch
14
+ import math
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import pooling_layers as pooling_layers
18
+ from fusion import AFF
19
+
20
+ class ReLU(nn.Hardtanh):
21
+
22
+ def __init__(self, inplace=False):
23
+ super(ReLU, self).__init__(0, 20, inplace)
24
+
25
+ def __repr__(self):
26
+ inplace_str = 'inplace' if self.inplace else ''
27
+ return self.__class__.__name__ + ' (' \
28
+ + inplace_str + ')'
29
+
30
+
31
+ class BasicBlockERes2Net(nn.Module):
32
+ expansion = 4
33
+
34
+ def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
35
+ super(BasicBlockERes2Net, self).__init__()
36
+ width = int(math.floor(planes*(baseWidth/64.0)))
37
+ self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
38
+ self.bn1 = nn.BatchNorm2d(width*scale)
39
+ self.nums = scale
40
+
41
+ convs=[]
42
+ bns=[]
43
+ for i in range(self.nums):
44
+ convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
45
+ bns.append(nn.BatchNorm2d(width))
46
+ self.convs = nn.ModuleList(convs)
47
+ self.bns = nn.ModuleList(bns)
48
+ self.relu = ReLU(inplace=True)
49
+
50
+ self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
51
+ self.bn3 = nn.BatchNorm2d(planes*self.expansion)
52
+ self.shortcut = nn.Sequential()
53
+ if stride != 1 or in_planes != self.expansion * planes:
54
+ self.shortcut = nn.Sequential(
55
+ nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
56
+ nn.BatchNorm2d(self.expansion * planes))
57
+ self.stride = stride
58
+ self.width = width
59
+ self.scale = scale
60
+
61
+ def forward(self, x):
62
+ residual = x
63
+
64
+ out = self.conv1(x)
65
+ out = self.bn1(out)
66
+ out = self.relu(out)
67
+ spx = torch.split(out,self.width,1)
68
+ for i in range(self.nums):
69
+ if i==0:
70
+ sp = spx[i]
71
+ else:
72
+ sp = sp + spx[i]
73
+ sp = self.convs[i](sp)
74
+ sp = self.relu(self.bns[i](sp))
75
+ if i==0:
76
+ out = sp
77
+ else:
78
+ out = torch.cat((out,sp),1)
79
+
80
+ out = self.conv3(out)
81
+ out = self.bn3(out)
82
+
83
+ residual = self.shortcut(x)
84
+ out += residual
85
+ out = self.relu(out)
86
+
87
+ return out
88
+
89
+ class BasicBlockERes2Net_diff_AFF(nn.Module):
90
+ expansion = 4
91
+
92
+ def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
93
+ super(BasicBlockERes2Net_diff_AFF, self).__init__()
94
+ width = int(math.floor(planes*(baseWidth/64.0)))
95
+ self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
96
+ self.bn1 = nn.BatchNorm2d(width*scale)
97
+ self.nums = scale
98
+
99
+ convs=[]
100
+ fuse_models=[]
101
+ bns=[]
102
+ for i in range(self.nums):
103
+ convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
104
+ bns.append(nn.BatchNorm2d(width))
105
+ for j in range(self.nums - 1):
106
+ fuse_models.append(AFF(channels=width))
107
+
108
+ self.convs = nn.ModuleList(convs)
109
+ self.bns = nn.ModuleList(bns)
110
+ self.fuse_models = nn.ModuleList(fuse_models)
111
+ self.relu = ReLU(inplace=True)
112
+
113
+ self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
114
+ self.bn3 = nn.BatchNorm2d(planes*self.expansion)
115
+ self.shortcut = nn.Sequential()
116
+ if stride != 1 or in_planes != self.expansion * planes:
117
+ self.shortcut = nn.Sequential(
118
+ nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
119
+ nn.BatchNorm2d(self.expansion * planes))
120
+ self.stride = stride
121
+ self.width = width
122
+ self.scale = scale
123
+
124
+ def forward(self, x):
125
+ residual = x
126
+
127
+ out = self.conv1(x)
128
+ out = self.bn1(out)
129
+ out = self.relu(out)
130
+ spx = torch.split(out,self.width,1)
131
+ for i in range(self.nums):
132
+ if i==0:
133
+ sp = spx[i]
134
+ else:
135
+ sp = self.fuse_models[i-1](sp, spx[i])
136
+
137
+ sp = self.convs[i](sp)
138
+ sp = self.relu(self.bns[i](sp))
139
+ if i==0:
140
+ out = sp
141
+ else:
142
+ out = torch.cat((out,sp),1)
143
+
144
+
145
+ out = self.conv3(out)
146
+ out = self.bn3(out)
147
+
148
+ residual = self.shortcut(x)
149
+ out += residual
150
+ out = self.relu(out)
151
+
152
+ return out
153
+
154
+ class ERes2Net(nn.Module):
155
+ def __init__(self,
156
+ block=BasicBlockERes2Net,
157
+ block_fuse=BasicBlockERes2Net_diff_AFF,
158
+ num_blocks=[3, 4, 6, 3],
159
+ m_channels=64,
160
+ feat_dim=80,
161
+ embedding_size=192,
162
+ pooling_func='TSTP',
163
+ two_emb_layer=False):
164
+ super(ERes2Net, self).__init__()
165
+ self.in_planes = m_channels
166
+ self.feat_dim = feat_dim
167
+ self.embedding_size = embedding_size
168
+ self.stats_dim = int(feat_dim / 8) * m_channels * 8
169
+ self.two_emb_layer = two_emb_layer
170
+
171
+ self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
172
+ self.bn1 = nn.BatchNorm2d(m_channels)
173
+
174
+ self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
175
+ self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
176
+ self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
177
+ self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
178
+
179
+ self.layer1_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False)
180
+ self.layer2_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False)
181
+ self.layer3_downsample = nn.Conv2d(m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, bias=False)
182
+
183
+ self.fuse_mode12 = AFF(channels=m_channels * 8)
184
+ self.fuse_mode123 = AFF(channels=m_channels * 16)
185
+ self.fuse_mode1234 = AFF(channels=m_channels * 32)
186
+
187
+ self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
188
+ self.pool = getattr(pooling_layers, pooling_func)(
189
+ in_dim=self.stats_dim * block.expansion)
190
+ self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
191
+ if self.two_emb_layer:
192
+ self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
193
+ self.seg_2 = nn.Linear(embedding_size, embedding_size)
194
+ else:
195
+ self.seg_bn_1 = nn.Identity()
196
+ self.seg_2 = nn.Identity()
197
+
198
+ def _make_layer(self, block, planes, num_blocks, stride):
199
+ strides = [stride] + [1] * (num_blocks - 1)
200
+ layers = []
201
+ for stride in strides:
202
+ layers.append(block(self.in_planes, planes, stride))
203
+ self.in_planes = planes * block.expansion
204
+ return nn.Sequential(*layers)
205
+
206
+ def forward(self, x):
207
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
208
+
209
+ x = x.unsqueeze_(1)
210
+ out = F.relu(self.bn1(self.conv1(x)))
211
+ out1 = self.layer1(out)
212
+ out2 = self.layer2(out1)
213
+ out1_downsample = self.layer1_downsample(out1)
214
+ fuse_out12 = self.fuse_mode12(out2, out1_downsample)
215
+ out3 = self.layer3(out2)
216
+ fuse_out12_downsample = self.layer2_downsample(fuse_out12)
217
+ fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
218
+ out4 = self.layer4(out3)
219
+ fuse_out123_downsample = self.layer3_downsample(fuse_out123)
220
+ fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
221
+ stats = self.pool(fuse_out1234)
222
+
223
+ embed_a = self.seg_1(stats)
224
+ if self.two_emb_layer:
225
+ out = F.relu(embed_a)
226
+ out = self.seg_bn_1(out)
227
+ embed_b = self.seg_2(out)
228
+ return embed_b
229
+ else:
230
+ return embed_a
231
+
232
+ def forward2(self, x,if_mean):
233
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
234
+
235
+ x = x.unsqueeze_(1)
236
+ out = F.relu(self.bn1(self.conv1(x)))
237
+ out1 = self.layer1(out)
238
+ out2 = self.layer2(out1)
239
+ out1_downsample = self.layer1_downsample(out1)
240
+ fuse_out12 = self.fuse_mode12(out2, out1_downsample)
241
+ out3 = self.layer3(out2)
242
+ fuse_out12_downsample = self.layer2_downsample(fuse_out12)
243
+ fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
244
+ out4 = self.layer4(out3)
245
+ fuse_out123_downsample = self.layer3_downsample(fuse_out123)
246
+ fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1,end_dim=2)#bs,20480,T
247
+ if(if_mean==False):
248
+ mean=fuse_out1234[0].transpose(1,0)#(T,20480),bs=T
249
+ else:
250
+ mean = fuse_out1234.mean(2)#bs,20480
251
+ mean_std=torch.cat([mean,torch.zeros_like(mean)],1)
252
+ return self.seg_1(mean_std)  # (T, 192) when if_mean is False, else (bs, 192)
253
+
254
+
255
+ # stats = self.pool(fuse_out1234)
256
+ # if self.two_emb_layer:
257
+ # out = F.relu(embed_a)
258
+ # out = self.seg_bn_1(out)
259
+ # embed_b = self.seg_2(out)
260
+ # return embed_b
261
+ # else:
262
+ # return embed_a
263
+
264
+ def forward3(self, x):
265
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
266
+
267
+ x = x.unsqueeze_(1)
268
+ out = F.relu(self.bn1(self.conv1(x)))
269
+ out1 = self.layer1(out)
270
+ out2 = self.layer2(out1)
271
+ out1_downsample = self.layer1_downsample(out1)
272
+ fuse_out12 = self.fuse_mode12(out2, out1_downsample)
273
+ out3 = self.layer3(out2)
274
+ fuse_out12_downsample = self.layer2_downsample(fuse_out12)
275
+ fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
276
+ out4 = self.layer4(out3)
277
+ fuse_out123_downsample = self.layer3_downsample(fuse_out123)
278
+ fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1,end_dim=2).mean(-1)
279
+ return fuse_out1234
280
+ # print(fuse_out1234.shape)
281
+ # print(fuse_out1234.flatten(start_dim=1,end_dim=2).shape)
282
+ # pdb.set_trace()
283
+
284
+
285
+
286
+
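Unlike the smaller variants, this huge ERes2Net also exposes `forward2`, which projects the fused feature map through `seg_1` either per utterance (`if_mean=True`) or per frame (`if_mean=False`, where only the first item of the batch is used). A minimal sketch under the same flat-import assumption:

    import sys, torch
    sys.path.append("eres2net")
    from ERes2Net_huge import ERes2Net

    model = ERes2Net(feat_dim=80, embedding_size=192)
    model.eval()
    x = torch.randn(1, 300, 80)                       # (B, T, F)
    with torch.no_grad():
        utt_emb = model.forward2(x, if_mean=True)     # (B, 192)
        frame_emb = model.forward2(x, if_mean=False)  # (T', 192), first batch item only
    print(utt_emb.shape, frame_emb.shape)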
eres2net/fusion.py ADDED
@@ -0,0 +1,29 @@
1
+ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
+ class AFF(nn.Module):
9
+
10
+ def __init__(self, channels=64, r=4):
11
+ super(AFF, self).__init__()
12
+ inter_channels = int(channels // r)
13
+
14
+ self.local_att = nn.Sequential(
15
+ nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0),
16
+ nn.BatchNorm2d(inter_channels),
17
+ nn.SiLU(inplace=True),
18
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
19
+ nn.BatchNorm2d(channels),
20
+ )
21
+
22
+ def forward(self, x, ds_y):
23
+ xa = torch.cat((x, ds_y), dim=1)
24
+ x_att = self.local_att(xa)
25
+ x_att = 1.0 + torch.tanh(x_att)
26
+ xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0-x_att)
27
+
28
+ return xo
29
+
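AFF fuses two same-shaped feature maps by predicting a per-position gate in (0, 2), x_att = 1 + tanh(local_att([x; y])), and mixing them as x * x_att + y * (2 - x_att), so the two branches always receive weights that sum to 2. A minimal sketch, assuming `eres2net/` is on `sys.path`:

    import sys, torch
    sys.path.append("eres2net")
    from fusion import AFF

    aff = AFF(channels=64, r=4)
    x = torch.randn(2, 64, 10, 25)        # deeper-layer features (B, C, F, T)
    y = torch.randn(2, 64, 10, 25)        # downsampled shallower features, same shape
    out = aff(x, y)                       # (2, 64, 10, 25)
    print(out.shape)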
eres2net/kaldi.py ADDED
@@ -0,0 +1,819 @@
1
+ import math
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ import torchaudio
6
+ from torch import Tensor
7
+
8
+ __all__ = [
9
+ "get_mel_banks",
10
+ "inverse_mel_scale",
11
+ "inverse_mel_scale_scalar",
12
+ "mel_scale",
13
+ "mel_scale_scalar",
14
+ "spectrogram",
15
+ "fbank",
16
+ "mfcc",
17
+ "vtln_warp_freq",
18
+ "vtln_warp_mel_freq",
19
+ ]
20
+
21
+ # numeric_limits<float>::epsilon() 1.1920928955078125e-07
22
+ EPSILON = torch.tensor(torch.finfo(torch.float).eps)
23
+ # 1 milliseconds = 0.001 seconds
24
+ MILLISECONDS_TO_SECONDS = 0.001
25
+
26
+ # window types
27
+ HAMMING = "hamming"
28
+ HANNING = "hanning"
29
+ POVEY = "povey"
30
+ RECTANGULAR = "rectangular"
31
+ BLACKMAN = "blackman"
32
+ WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]
33
+
34
+
35
+ def _get_epsilon(device, dtype):
36
+ return EPSILON.to(device=device, dtype=dtype)
37
+
38
+
39
+ def _next_power_of_2(x: int) -> int:
40
+ r"""Returns the smallest power of 2 that is greater than x"""
41
+ return 1 if x == 0 else 2 ** (x - 1).bit_length()
42
+
43
+
44
+ def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
45
+ r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``)
46
+ representing how the window is shifted along the waveform. Each row is a frame.
47
+
48
+ Args:
49
+ waveform (Tensor): Tensor of size ``num_samples``
50
+ window_size (int): Frame length
51
+ window_shift (int): Frame shift
52
+ snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
53
+ in the file, and the number of frames depends on the frame_length. If False, the number of frames
54
+ depends only on the frame_shift, and we reflect the data at the ends.
55
+
56
+ Returns:
57
+ Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame
58
+ """
59
+ assert waveform.dim() == 1
60
+ num_samples = waveform.size(0)
61
+ strides = (window_shift * waveform.stride(0), waveform.stride(0))
62
+
63
+ if snip_edges:
64
+ if num_samples < window_size:
65
+ return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
66
+ else:
67
+ m = 1 + (num_samples - window_size) // window_shift
68
+ else:
69
+ reversed_waveform = torch.flip(waveform, [0])
70
+ m = (num_samples + (window_shift // 2)) // window_shift
71
+ pad = window_size // 2 - window_shift // 2
72
+ pad_right = reversed_waveform
73
+ if pad > 0:
74
+ # torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
75
+ # but we want [2, 1, 0, 0, 1, 2]
76
+ pad_left = reversed_waveform[-pad:]
77
+ waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
78
+ else:
79
+ # pad is negative so we want to trim the waveform at the front
80
+ waveform = torch.cat((waveform[-pad:], pad_right), dim=0)
81
+
82
+ sizes = (m, window_size)
83
+ return waveform.as_strided(sizes, strides)
84
+
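As a worked example of the framing above (values illustrative): at 16 kHz with the default 25 ms frame length and 10 ms shift, window_size = 400 and window_shift = 160, so 1 s of audio (16000 samples) gives m = 1 + (16000 - 400) // 160 = 98 frames with snip_edges=True, and m = (16000 + 80) // 160 = 100 frames with snip_edges=False (the ends are filled by reflection).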
85
+
86
+ def _feature_window_function(
87
+ window_type: str,
88
+ window_size: int,
89
+ blackman_coeff: float,
90
+ device: torch.device,
91
+ dtype: int,
92
+ ) -> Tensor:
93
+ r"""Returns a window function with the given type and size"""
94
+ if window_type == HANNING:
95
+ return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
96
+ elif window_type == HAMMING:
97
+ return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype)
98
+ elif window_type == POVEY:
99
+ # like hanning but goes to zero at edges
100
+ return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85)
101
+ elif window_type == RECTANGULAR:
102
+ return torch.ones(window_size, device=device, dtype=dtype)
103
+ elif window_type == BLACKMAN:
104
+ a = 2 * math.pi / (window_size - 1)
105
+ window_function = torch.arange(window_size, device=device, dtype=dtype)
106
+ # can't use torch.blackman_window as they use different coefficients
107
+ return (
108
+ blackman_coeff
109
+ - 0.5 * torch.cos(a * window_function)
110
+ + (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)
111
+ ).to(device=device, dtype=dtype)
112
+ else:
113
+ raise Exception("Invalid window type " + window_type)
114
+
115
+
116
+ def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
117
+ r"""Returns the log energy of size (m) for a strided_input (m,*)"""
118
+ device, dtype = strided_input.device, strided_input.dtype
119
+ log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m)
120
+ if energy_floor == 0.0:
121
+ return log_energy
122
+ return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))
123
+
124
+
125
+ def _get_waveform_and_window_properties(
126
+ waveform: Tensor,
127
+ channel: int,
128
+ sample_frequency: float,
129
+ frame_shift: float,
130
+ frame_length: float,
131
+ round_to_power_of_two: bool,
132
+ preemphasis_coefficient: float,
133
+ ) -> Tuple[Tensor, int, int, int]:
134
+ r"""Gets the waveform and window properties"""
135
+ channel = max(channel, 0)
136
+ assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0))
137
+ waveform = waveform[channel, :] # size (n)
138
+ window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
139
+ window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
140
+ padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size
141
+
142
+ assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format(
143
+ window_size, len(waveform)
144
+ )
145
+ assert 0 < window_shift, "`window_shift` must be greater than 0"
146
+ assert padded_window_size % 2 == 0, (
147
+ "the padded `window_size` must be divisible by two." " use `round_to_power_of_two` or change `frame_length`"
148
+ )
149
+ assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
150
+ assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
151
+ return waveform, window_shift, window_size, padded_window_size
152
+
153
+
154
+ def _get_window(
155
+ waveform: Tensor,
156
+ padded_window_size: int,
157
+ window_size: int,
158
+ window_shift: int,
159
+ window_type: str,
160
+ blackman_coeff: float,
161
+ snip_edges: bool,
162
+ raw_energy: bool,
163
+ energy_floor: float,
164
+ dither: float,
165
+ remove_dc_offset: bool,
166
+ preemphasis_coefficient: float,
167
+ ) -> Tuple[Tensor, Tensor]:
168
+ r"""Gets a window and its log energy
169
+
170
+ Returns:
171
+ (Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m)
172
+ """
173
+ device, dtype = waveform.device, waveform.dtype
174
+ epsilon = _get_epsilon(device, dtype)
175
+
176
+ # size (m, window_size)
177
+ strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)
178
+
179
+ if dither != 0.0:
180
+ rand_gauss = torch.randn(strided_input.shape, device=device, dtype=dtype)
181
+ strided_input = strided_input + rand_gauss * dither
182
+
183
+ if remove_dc_offset:
184
+ # Subtract each row/frame by its mean
185
+ row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1)
186
+ strided_input = strided_input - row_means
187
+
188
+ if raw_energy:
189
+ # Compute the log energy of each row/frame before applying preemphasis and
190
+ # window function
191
+ signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
192
+
193
+ if preemphasis_coefficient != 0.0:
194
+ # strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
195
+ offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(
196
+ 0
197
+ ) # size (m, window_size + 1)
198
+ strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1]
199
+
200
+ # Apply window_function to each row/frame
201
+ window_function = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype).unsqueeze(
202
+ 0
203
+ ) # size (1, window_size)
204
+ strided_input = strided_input * window_function # size (m, window_size)
205
+
206
+ # Pad columns with zero until we reach size (m, padded_window_size)
207
+ if padded_window_size != window_size:
208
+ padding_right = padded_window_size - window_size
209
+ strided_input = torch.nn.functional.pad(
210
+ strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0
211
+ ).squeeze(0)
212
+
213
+ # Compute energy after window function (not the raw one)
214
+ if not raw_energy:
215
+ signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
216
+
217
+ return strided_input, signal_log_energy
218
+
219
+
220
+ def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
221
+ # subtracts the column mean of the tensor size (m, n) if subtract_mean=True
222
+ # it returns size (m, n)
223
+ if subtract_mean:
224
+ col_means = torch.mean(tensor, dim=0).unsqueeze(0)
225
+ tensor = tensor - col_means
226
+ return tensor
227
+
228
+
229
+ def spectrogram(
230
+ waveform: Tensor,
231
+ blackman_coeff: float = 0.42,
232
+ channel: int = -1,
233
+ dither: float = 0.0,
234
+ energy_floor: float = 1.0,
235
+ frame_length: float = 25.0,
236
+ frame_shift: float = 10.0,
237
+ min_duration: float = 0.0,
238
+ preemphasis_coefficient: float = 0.97,
239
+ raw_energy: bool = True,
240
+ remove_dc_offset: bool = True,
241
+ round_to_power_of_two: bool = True,
242
+ sample_frequency: float = 16000.0,
243
+ snip_edges: bool = True,
244
+ subtract_mean: bool = False,
245
+ window_type: str = POVEY,
246
+ ) -> Tensor:
247
+ r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's
248
+ compute-spectrogram-feats.
249
+
250
+ Args:
251
+ waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
252
+ blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
253
+ channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
254
+ dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
255
+ the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
256
+ energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
257
+ this floor is applied to the zeroth component, representing the total signal energy. The floor on the
258
+ individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
259
+ frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
260
+ frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
261
+ min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
262
+ preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
263
+ raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
264
+ remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
265
+ round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
266
+ to FFT. (Default: ``True``)
267
+ sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
268
+ specified there) (Default: ``16000.0``)
269
+ snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
270
+ in the file, and the number of frames depends on the frame_length. If False, the number of frames
271
+ depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
272
+ subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
273
+ it this way. (Default: ``False``)
274
+ window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
275
+ (Default: ``'povey'``)
276
+
277
+ Returns:
278
+ Tensor: A spectrogram identical to what Kaldi would output. The shape is
279
+ (m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided
280
+ """
281
+ device, dtype = waveform.device, waveform.dtype
282
+ epsilon = _get_epsilon(device, dtype)
283
+
284
+ waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
285
+ waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
286
+ )
287
+
288
+ if len(waveform) < min_duration * sample_frequency:
289
+ # signal is too short
290
+ return torch.empty(0)
291
+
292
+ strided_input, signal_log_energy = _get_window(
293
+ waveform,
294
+ padded_window_size,
295
+ window_size,
296
+ window_shift,
297
+ window_type,
298
+ blackman_coeff,
299
+ snip_edges,
300
+ raw_energy,
301
+ energy_floor,
302
+ dither,
303
+ remove_dc_offset,
304
+ preemphasis_coefficient,
305
+ )
306
+
307
+ # size (m, padded_window_size // 2 + 1, 2)
308
+ fft = torch.fft.rfft(strided_input)
309
+
310
+ # Convert the FFT into a power spectrum
311
+ power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log() # size (m, padded_window_size // 2 + 1)
312
+ power_spectrum[:, 0] = signal_log_energy
313
+
314
+ power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
315
+ return power_spectrum
316
+
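A brief usage sketch for `spectrogram` above; the wav path is a placeholder and `torchaudio.load` is used only to obtain a (channels, num_samples) float tensor.

    import torchaudio
    import kaldi                                      # this file, with eres2net/ on sys.path

    waveform, sr = torchaudio.load("example.wav")     # placeholder path
    spec = kaldi.spectrogram(waveform, sample_frequency=sr)
    print(spec.shape)                                 # (m, padded_window_size // 2 + 1)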
317
+
318
+ def inverse_mel_scale_scalar(mel_freq: float) -> float:
319
+ return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)
320
+
321
+
322
+ def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
323
+ return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)
324
+
325
+
326
+ def mel_scale_scalar(freq: float) -> float:
327
+ return 1127.0 * math.log(1.0 + freq / 700.0)
328
+
329
+
330
+ def mel_scale(freq: Tensor) -> Tensor:
331
+ return 1127.0 * (1.0 + freq / 700.0).log()
332
+
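As a sanity check on the scalar helpers above: mel_scale_scalar(1000.0) = 1127 * ln(1 + 1000/700) ≈ 1000.0 (this mel formula pins 1000 Hz at roughly 1000 mel), and inverse_mel_scale_scalar undoes it, so round-tripping any frequency through the pair returns it up to floating-point error; the tensor versions apply the same formulas element-wise.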
333
+
334
+ def vtln_warp_freq(
335
+ vtln_low_cutoff: float,
336
+ vtln_high_cutoff: float,
337
+ low_freq: float,
338
+ high_freq: float,
339
+ vtln_warp_factor: float,
340
+ freq: Tensor,
341
+ ) -> Tensor:
342
+ r"""This computes a VTLN warping function that is not the same as HTK's one,
343
+ but has similar inputs (this function has the advantage of never producing
344
+ empty bins).
345
+
346
+ This function computes a warp function F(freq), defined between low_freq
347
+ and high_freq inclusive, with the following properties:
348
+ F(low_freq) == low_freq
349
+ F(high_freq) == high_freq
350
+ The function is continuous and piecewise linear with two inflection
351
+ points.
352
+ The lower inflection point (measured in terms of the unwarped
353
+ frequency) is at frequency l, determined as described below.
354
+ The higher inflection point is at a frequency h, determined as
355
+ described below.
356
+ If l <= f <= h, then F(f) = f/vtln_warp_factor.
357
+ If the higher inflection point (measured in terms of the unwarped
358
+ frequency) is at h, then max(h, F(h)) == vtln_high_cutoff.
359
+ Since (by the last point) F(h) == h/vtln_warp_factor, then
360
+ max(h, h/vtln_warp_factor) == vtln_high_cutoff, so
361
+ h = vtln_high_cutoff / max(1, 1/vtln_warp_factor).
362
+ = vtln_high_cutoff * min(1, vtln_warp_factor).
363
+ If the lower inflection point (measured in terms of the unwarped
364
+ frequency) is at l, then min(l, F(l)) == vtln_low_cutoff
365
+ This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor)
366
+ = vtln_low_cutoff * max(1, vtln_warp_factor)
367
+ Args:
368
+ vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
369
+ vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
370
+ low_freq (float): Lower frequency cutoffs in mel computation
371
+ high_freq (float): Upper frequency cutoffs in mel computation
372
+ vtln_warp_factor (float): Vtln warp factor
373
+ freq (Tensor): given frequency in Hz
374
+
375
+ Returns:
376
+ Tensor: Freq after vtln warp
377
+ """
378
+ assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq"
379
+ assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]"
380
+ l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
381
+ h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
382
+ scale = 1.0 / vtln_warp_factor
383
+ Fl = scale * l # F(l)
384
+ Fh = scale * h # F(h)
385
+ assert l > low_freq and h < high_freq
386
+ # slope of left part of the 3-piece linear function
387
+ scale_left = (Fl - low_freq) / (l - low_freq)
388
+ # [slope of center part is just "scale"]
389
+
390
+ # slope of right part of the 3-piece linear function
391
+ scale_right = (high_freq - Fh) / (high_freq - h)
392
+
393
+ res = torch.empty_like(freq)
394
+
395
+ outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(freq, high_freq) # freq < low_freq || freq > high_freq
396
+ before_l = torch.lt(freq, l) # freq < l
397
+ before_h = torch.lt(freq, h) # freq < h
398
+ after_h = torch.ge(freq, h) # freq >= h
399
+
400
+ # order of operations matter here (since there is overlapping frequency regions)
401
+ res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
402
+ res[before_h] = scale * freq[before_h]
403
+ res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
404
+ res[outside_low_high_freq] = freq[outside_low_high_freq]
405
+
406
+ return res
407
+
408
+
409
+ def vtln_warp_mel_freq(
410
+ vtln_low_cutoff: float,
411
+ vtln_high_cutoff: float,
412
+ low_freq,
413
+ high_freq: float,
414
+ vtln_warp_factor: float,
415
+ mel_freq: Tensor,
416
+ ) -> Tensor:
417
+ r"""
418
+ Args:
419
+ vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
420
+ vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
421
+ low_freq (float): Lower frequency cutoffs in mel computation
422
+ high_freq (float): Upper frequency cutoffs in mel computation
423
+ vtln_warp_factor (float): Vtln warp factor
424
+ mel_freq (Tensor): Given frequency in Mel
425
+
426
+ Returns:
427
+ Tensor: ``mel_freq`` after vtln warp
428
+ """
429
+ return mel_scale(
430
+ vtln_warp_freq(
431
+ vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq)
432
+ )
433
+ )
434
+
435
+
436
+ def get_mel_banks(
437
+ num_bins: int,
438
+ window_length_padded: int,
439
+ sample_freq: float,
440
+ low_freq: float,
441
+ high_freq: float,
442
+ vtln_low: float,
443
+ vtln_high: float,
444
+ vtln_warp_factor: float,device=None,dtype=None
445
+ ) -> Tuple[Tensor, Tensor]:
446
+ """
447
+ Returns:
448
+ (Tensor, Tensor): The tuple consists of ``bins`` (which is
449
+ melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is
450
+ center frequencies of bins of size (``num_bins``)).
451
+ """
452
+ assert num_bins > 3, "Must have at least 3 mel bins"
453
+ assert window_length_padded % 2 == 0
454
+ num_fft_bins = window_length_padded / 2
455
+ nyquist = 0.5 * sample_freq
456
+
457
+ if high_freq <= 0.0:
458
+ high_freq += nyquist
459
+
460
+ assert (
461
+ (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)
462
+ ), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)
463
+
464
+ # fft-bin width [think of it as Nyquist-freq / half-window-length]
465
+ fft_bin_width = sample_freq / window_length_padded
466
+ mel_low_freq = mel_scale_scalar(low_freq)
467
+ mel_high_freq = mel_scale_scalar(high_freq)
468
+
469
+ # divide by num_bins+1 in next line because of end-effects where the bins
470
+ # spread out to the sides.
471
+ mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
472
+
473
+ if vtln_high < 0.0:
474
+ vtln_high += nyquist
475
+
476
+ assert vtln_warp_factor == 1.0 or (
477
+ (low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
478
+ ), "Bad values in options: vtln-low {} and vtln-high {}, versus " "low-freq {} and high-freq {}".format(
479
+ vtln_low, vtln_high, low_freq, high_freq
480
+ )
481
+
482
+ bin = torch.arange(num_bins).unsqueeze(1)
483
+ left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1)
484
+ center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # size(num_bins, 1)
485
+ right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1)
486
+
487
+ if vtln_warp_factor != 1.0:
488
+ left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel)
489
+ center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel)
490
+ right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel)
491
+
492
+ # center_freqs = inverse_mel_scale(center_mel) # size (num_bins)
493
+ # size(1, num_fft_bins)
494
+ mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)
495
+
496
+ # size (num_bins, num_fft_bins)
497
+ up_slope = (mel - left_mel) / (center_mel - left_mel)
498
+ down_slope = (right_mel - mel) / (right_mel - center_mel)
499
+
500
+ if vtln_warp_factor == 1.0:
501
+ # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
502
+ bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))
503
+ else:
504
+ # warping can move the order of left_mel, center_mel, right_mel anywhere
505
+ bins = torch.zeros_like(up_slope)
506
+ up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel) # left_mel < mel <= center_mel
507
+ down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel) # center_mel < mel < right_mel
508
+ bins[up_idx] = up_slope[up_idx]
509
+ bins[down_idx] = down_slope[down_idx]
510
+
511
+ return bins.to(device=device,dtype=dtype)#, center_freqs
512
+
513
+ cache = {}  # memoizes mel filterbank matrices from get_mel_banks, keyed by their construction parameters
514
+ def fbank(
515
+ waveform: Tensor,
516
+ blackman_coeff: float = 0.42,
517
+ channel: int = -1,
518
+ dither: float = 0.0,
519
+ energy_floor: float = 1.0,
520
+ frame_length: float = 25.0,
521
+ frame_shift: float = 10.0,
522
+ high_freq: float = 0.0,
523
+ htk_compat: bool = False,
524
+ low_freq: float = 20.0,
525
+ min_duration: float = 0.0,
526
+ num_mel_bins: int = 23,
527
+ preemphasis_coefficient: float = 0.97,
528
+ raw_energy: bool = True,
529
+ remove_dc_offset: bool = True,
530
+ round_to_power_of_two: bool = True,
531
+ sample_frequency: float = 16000.0,
532
+ snip_edges: bool = True,
533
+ subtract_mean: bool = False,
534
+ use_energy: bool = False,
535
+ use_log_fbank: bool = True,
536
+ use_power: bool = True,
537
+ vtln_high: float = -500.0,
538
+ vtln_low: float = 100.0,
539
+ vtln_warp: float = 1.0,
540
+ window_type: str = POVEY,
541
+ ) -> Tensor:
542
+ r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's
543
+ compute-fbank-feats.
544
+
545
+ Args:
546
+ waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
547
+ blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
548
+ channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
549
+ dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
550
+ the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
551
+ energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
552
+ this floor is applied to the zeroth component, representing the total signal energy. The floor on the
553
+ individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
554
+ frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
555
+ frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
556
+ high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
557
+ (Default: ``0.0``)
558
+ htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features
559
+ (need to change other parameters). (Default: ``False``)
560
+ low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
561
+ min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
562
+ num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
563
+ preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
564
+ raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
565
+ remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
566
+ round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
567
+ to FFT. (Default: ``True``)
568
+ sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
569
+ specified there) (Default: ``16000.0``)
570
+ snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
571
+ in the file, and the number of frames depends on the frame_length. If False, the number of frames
572
+ depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
573
+ subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
574
+ it this way. (Default: ``False``)
575
+ use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
576
+ use_log_fbank (bool, optional): If true, produce log-filterbank, else produce linear. (Default: ``True``)
577
+ use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``)
578
+ vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
579
+ negative, offset from high-mel-freq (Default: ``-500.0``)
580
+ vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
581
+ vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
582
+ window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
583
+ (Default: ``'povey'``)
584
+
585
+ Returns:
586
+ Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``)
587
+ where m is calculated in _get_strided
588
+ """
589
+ device, dtype = waveform.device, waveform.dtype
590
+
591
+ waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
592
+ waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
593
+ )
594
+
595
+ if len(waveform) < min_duration * sample_frequency:
596
+ # signal is too short
597
+ return torch.empty(0, device=device, dtype=dtype)
598
+
599
+ # strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
600
+ strided_input, signal_log_energy = _get_window(
601
+ waveform,
602
+ padded_window_size,
603
+ window_size,
604
+ window_shift,
605
+ window_type,
606
+ blackman_coeff,
607
+ snip_edges,
608
+ raw_energy,
609
+ energy_floor,
610
+ dither,
611
+ remove_dc_offset,
612
+ preemphasis_coefficient,
613
+ )
614
+
615
+ # size (m, padded_window_size // 2 + 1)
616
+ spectrum = torch.fft.rfft(strided_input).abs()
617
+ if use_power:
618
+ spectrum = spectrum.pow(2.0)
619
+
620
+ # size (num_mel_bins, padded_window_size // 2)
621
+ # print(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp)
622
+
623
+ cache_key="%s-%s-%s-%s-%s-%s-%s-%s-%s-%s"%(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp,device,dtype)
624
+ if cache_key not in cache:
625
+ mel_energies = get_mel_banks(
626
+ num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp,device,dtype
627
+ )
628
+ cache[cache_key]=mel_energies
629
+ else:
630
+ mel_energies=cache[cache_key]
631
+
632
+ # pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
633
+ mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)
634
+
635
+ # sum with mel filterbanks over the power spectrum, size (m, num_mel_bins)
636
+ mel_energies = torch.mm(spectrum, mel_energies.T)
637
+ if use_log_fbank:
638
+ # avoid log of zero (which should be prevented anyway by dithering)
639
+ mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()
640
+
641
+ # if use_energy then add it as the last column for htk_compat == true else first column
642
+ if use_energy:
643
+ signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1)
644
+ # returns size (m, num_mel_bins + 1)
645
+ if htk_compat:
646
+ mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1)
647
+ else:
648
+ mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1)
649
+
650
+ mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
651
+ return mel_energies
652
+
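A minimal sketch connecting this `fbank` to the ERes2Net models earlier in this commit, which default to feat_dim=80; the wav path and the mono 16 kHz assumption are placeholders.

    import torchaudio
    import kaldi                                      # this file, with eres2net/ on sys.path

    waveform, sr = torchaudio.load("speaker.wav")     # placeholder; expected mono, 16 kHz
    feats = kaldi.fbank(waveform, num_mel_bins=80,
                        sample_frequency=sr, dither=0.0)   # (m, 80) log-mel features
    feats = feats.unsqueeze(0)                        # (B, T, F), ready for the ERes2Net forward passes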
653
+
654
+ def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
655
+ # returns a dct matrix of size (num_mel_bins, num_ceps)
656
+ # size (num_mel_bins, num_mel_bins)
657
+ dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho")
658
+ # kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins)
659
+ # this would be the first column in the dct_matrix for torchaudio as it expects a
660
+ # right multiply (which would be the first column of the kaldi's dct_matrix as kaldi
661
+ # expects a left multiply e.g. dct_matrix * vector).
662
+ dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins))
663
+ dct_matrix = dct_matrix[:, :num_ceps]
664
+ return dct_matrix
665
+
666
+
667
+ def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
668
+ # returns size (num_ceps)
669
+ # Compute liftering coefficients (scaling on cepstral coeffs)
670
+ # coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
671
+ i = torch.arange(num_ceps)
672
+ return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter)
673
+
674
+
675
+ def mfcc(
676
+ waveform: Tensor,
677
+ blackman_coeff: float = 0.42,
678
+ cepstral_lifter: float = 22.0,
679
+ channel: int = -1,
680
+ dither: float = 0.0,
681
+ energy_floor: float = 1.0,
682
+ frame_length: float = 25.0,
683
+ frame_shift: float = 10.0,
684
+ high_freq: float = 0.0,
685
+ htk_compat: bool = False,
686
+ low_freq: float = 20.0,
687
+ num_ceps: int = 13,
688
+ min_duration: float = 0.0,
689
+ num_mel_bins: int = 23,
690
+ preemphasis_coefficient: float = 0.97,
691
+ raw_energy: bool = True,
692
+ remove_dc_offset: bool = True,
693
+ round_to_power_of_two: bool = True,
694
+ sample_frequency: float = 16000.0,
695
+ snip_edges: bool = True,
696
+ subtract_mean: bool = False,
697
+ use_energy: bool = False,
698
+ vtln_high: float = -500.0,
699
+ vtln_low: float = 100.0,
700
+ vtln_warp: float = 1.0,
701
+ window_type: str = POVEY,
702
+ ) -> Tensor:
703
+ r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's
704
+ compute-mfcc-feats.
705
+
706
+ Args:
707
+ waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
708
+ blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
709
+ cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``)
710
+ channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
711
+ dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
712
+ the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
713
+ energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
714
+ this floor is applied to the zeroth component, representing the total signal energy. The floor on the
715
+ individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
716
+ frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
717
+ frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
718
+ high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
719
+ (Default: ``0.0``)
720
+ htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible
721
+ features (need to change other parameters). (Default: ``False``)
722
+ low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
723
+ num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``)
724
+ min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
725
+ num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
726
+ preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
727
+ raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
728
+ remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
729
+ round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
730
+ to FFT. (Default: ``True``)
731
+ sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
732
+ specified there) (Default: ``16000.0``)
733
+ snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
734
+ in the file, and the number of frames depends on the frame_length. If False, the number of frames
735
+ depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
736
+ subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
737
+ it this way. (Default: ``False``)
738
+ use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
739
+ vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
740
+ negative, offset from high-mel-freq (Default: ``-500.0``)
741
+ vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
742
+ vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
743
+ window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
744
+ (Default: ``"povey"``)
745
+
746
+ Returns:
747
+ Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``)
748
+ where m is calculated in _get_strided
749
+ """
750
+ assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins)
751
+
752
+ device, dtype = waveform.device, waveform.dtype
753
+
754
+ # The mel_energies should be squared (use_power=True), not have mean subtracted
755
+ # (subtract_mean=False), and use log (use_log_fbank=True).
756
+ # size (m, num_mel_bins + use_energy)
757
+ feature = fbank(
758
+ waveform=waveform,
759
+ blackman_coeff=blackman_coeff,
760
+ channel=channel,
761
+ dither=dither,
762
+ energy_floor=energy_floor,
763
+ frame_length=frame_length,
764
+ frame_shift=frame_shift,
765
+ high_freq=high_freq,
766
+ htk_compat=htk_compat,
767
+ low_freq=low_freq,
768
+ min_duration=min_duration,
769
+ num_mel_bins=num_mel_bins,
770
+ preemphasis_coefficient=preemphasis_coefficient,
771
+ raw_energy=raw_energy,
772
+ remove_dc_offset=remove_dc_offset,
773
+ round_to_power_of_two=round_to_power_of_two,
774
+ sample_frequency=sample_frequency,
775
+ snip_edges=snip_edges,
776
+ subtract_mean=False,
777
+ use_energy=use_energy,
778
+ use_log_fbank=True,
779
+ use_power=True,
780
+ vtln_high=vtln_high,
781
+ vtln_low=vtln_low,
782
+ vtln_warp=vtln_warp,
783
+ window_type=window_type,
784
+ )
785
+
786
+ if use_energy:
787
+ # size (m)
788
+ signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
789
+ # offset is 0 if htk_compat==True else 1
790
+ mel_offset = int(not htk_compat)
791
+ feature = feature[:, mel_offset : (num_mel_bins + mel_offset)]
792
+
793
+ # size (num_mel_bins, num_ceps)
794
+ dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device)
795
+
796
+ # size (m, num_ceps)
797
+ feature = feature.matmul(dct_matrix)
798
+
799
+ if cepstral_lifter != 0.0:
800
+ # size (1, num_ceps)
801
+ lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
802
+ feature *= lifter_coeffs.to(device=device, dtype=dtype)
803
+
804
+ # if use_energy then replace the last column for htk_compat == true else first column
805
+ if use_energy:
806
+ feature[:, 0] = signal_log_energy
807
+
808
+ if htk_compat:
809
+ energy = feature[:, 0].unsqueeze(1) # size (m, 1)
810
+ feature = feature[:, 1:] # size (m, num_ceps - 1)
811
+ if not use_energy:
812
+ # scale on C0 (actually removing a scale we previously added that's
813
+ # part of one common definition of the cosine transform.)
814
+ energy *= math.sqrt(2)
815
+
816
+ feature = torch.cat((feature, energy), dim=1)
817
+
818
+ feature = _subtract_column_mean(feature, subtract_mean)
819
+ return feature
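
A minimal usage sketch (not part of the committed diff) of the Kaldi-compatible front end above; the 80-bin/16 kHz fbank settings and the `from eres2net import kaldi` import path are assumptions for illustration only:

import torch
from eres2net import kaldi  # assumed import path for this file

waveform = torch.randn(1, 16000)  # (channels, samples): 1 s of 16 kHz mono audio
fbank_feats = kaldi.fbank(waveform, num_mel_bins=80, sample_frequency=16000.0, dither=0.0)
mfcc_feats = kaldi.mfcc(waveform, num_ceps=13, num_mel_bins=23, sample_frequency=16000.0)
print(fbank_feats.shape)  # (m, 80) log-mel frames, with m computed in _get_strided
print(mfcc_feats.shape)   # (m, 13) cepstra

The fbank path caches the mel filterbank matrix keyed on its parameters, device and dtype, so repeated calls with the same configuration skip the get_mel_banks recomputation.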
eres2net/pooling_layers.py ADDED
@@ -0,0 +1,104 @@
1
+ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ """ This implementation is adapted from https://github.com/wenet-e2e/wespeaker."""
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ class TAP(nn.Module):
11
+ """
12
+ Temporal average pooling, only first-order mean is considered
13
+ """
14
+ def __init__(self, **kwargs):
15
+ super(TAP, self).__init__()
16
+
17
+ def forward(self, x):
18
+ pooling_mean = x.mean(dim=-1)
19
+ # To be compatible with 2D input
20
+ pooling_mean = pooling_mean.flatten(start_dim=1)
21
+ return pooling_mean
22
+
23
+
24
+ class TSDP(nn.Module):
25
+ """
26
+ Temporal standard deviation pooling, only second-order std is considered
27
+ """
28
+ def __init__(self, **kwargs):
29
+ super(TSDP, self).__init__()
30
+
31
+ def forward(self, x):
32
+ # The last dimension is the temporal axis
33
+ pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
34
+ pooling_std = pooling_std.flatten(start_dim=1)
35
+ return pooling_std
36
+
37
+
38
+ class TSTP(nn.Module):
39
+ """
40
+ Temporal statistics pooling: concatenate mean and std, as used in the
41
+ x-vector architecture.
42
+ Comment: simple concatenation cannot make full use of both statistics.
43
+ """
44
+ def __init__(self, **kwargs):
45
+ super(TSTP, self).__init__()
46
+
47
+ def forward(self, x):
48
+ # The last dimension is the temporal axis
49
+ pooling_mean = x.mean(dim=-1)
50
+ pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
51
+ pooling_mean = pooling_mean.flatten(start_dim=1)
52
+ pooling_std = pooling_std.flatten(start_dim=1)
53
+
54
+ stats = torch.cat((pooling_mean, pooling_std), 1)
55
+ return stats
56
+
57
+
58
+ class ASTP(nn.Module):
59
+ """ Attentive statistics pooling: Channel- and context-dependent
60
+ statistics pooling, first used in ECAPA_TDNN.
61
+ """
62
+ def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
63
+ super(ASTP, self).__init__()
64
+ self.global_context_att = global_context_att
65
+
66
+ # Use Conv1d with stride == 1 rather than Linear, then we don't
67
+ # need to transpose inputs.
68
+ if global_context_att:
69
+ self.linear1 = nn.Conv1d(
70
+ in_dim * 3, bottleneck_dim,
71
+ kernel_size=1) # equals W and b in the paper
72
+ else:
73
+ self.linear1 = nn.Conv1d(
74
+ in_dim, bottleneck_dim,
75
+ kernel_size=1) # equals W and b in the paper
76
+ self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
77
+ kernel_size=1) # equals V and k in the paper
78
+
79
+ def forward(self, x):
80
+ """
81
+ x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
82
+ or a 4-dimensional tensor in resnet architecture (B,C,F,T)
83
+ 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
84
+ """
85
+ if len(x.shape) == 4:
86
+ x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
87
+ assert len(x.shape) == 3
88
+
89
+ if self.global_context_att:
90
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
91
+ context_std = torch.sqrt(
92
+ torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
93
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
94
+ else:
95
+ x_in = x
96
+
97
+ # DON'T use ReLU here! ReLU may make it hard to converge.
98
+ alpha = torch.tanh(
99
+ self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
100
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
101
+ mean = torch.sum(alpha * x, dim=2)
102
+ var = torch.sum(alpha * (x**2), dim=2) - mean**2
103
+ std = torch.sqrt(var.clamp(min=1e-10))
104
+ return torch.cat([mean, std], dim=1)
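
A brief, illustrative sketch of how these pooling heads reduce frame-level features to a fixed-length embedding; the 256-dim/200-frame input shape and the import path are assumptions:

import torch
from eres2net.pooling_layers import TSTP, ASTP  # assumed import path

frames = torch.randn(4, 256, 200)      # (batch, feature_dim, time)
print(TSTP()(frames).shape)            # torch.Size([4, 512]): mean and std concatenated
print(ASTP(in_dim=256)(frames).shape)  # torch.Size([4, 512]): attention-weighted mean and std

Both statistics poolings double the feature dimension because they concatenate a mean vector with a standard-deviation vector over the time axis.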
feature_extractor/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from . import cnhubert, whisper_enc
2
+
3
+ content_module_map = {
4
+ 'cnhubert': cnhubert,
5
+ 'whisper': whisper_enc
6
+ }
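
Illustrative only: content_module_map lets callers select the SSL content extractor by name; the checkpoint path below is an assumption, not something shipped in this commit:

from feature_extractor import cnhubert, content_module_map

cnhubert.cnhubert_base_path = "pretrained_models/chinese-hubert-base"  # assumed local path
ssl_module = content_module_map["cnhubert"]  # same module object as `cnhubert`
ssl_model = ssl_module.get_model()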
feature_extractor/cnhubert.py ADDED
@@ -0,0 +1,109 @@
1
+ import time
2
+
3
+ import librosa
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import soundfile as sf
7
+ import os
8
+ from transformers import logging as tf_logging
9
+ tf_logging.set_verbosity_error()
10
+
11
+ import logging
12
+ logging.getLogger("numba").setLevel(logging.WARNING)
13
+
14
+ from transformers import (
15
+ Wav2Vec2FeatureExtractor,
16
+ HubertModel,
17
+ )
18
+
19
+ import utils
20
+ import torch.nn as nn
21
+
22
+ cnhubert_base_path = None
23
+
24
+
25
+ class CNHubert(nn.Module):
26
+ def __init__(self):
27
+ super().__init__()
28
+ if not os.path.exists(cnhubert_base_path):
29
+ raise FileNotFoundError(cnhubert_base_path)
30
+ self.model = HubertModel.from_pretrained(cnhubert_base_path, local_files_only=True)
31
+ self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
32
+ cnhubert_base_path, local_files_only=True
33
+ )
34
+
35
+ def forward(self, x):
36
+ input_values = self.feature_extractor(
37
+ x, return_tensors="pt", sampling_rate=16000
38
+ ).input_values.to(x.device)
39
+ feats = self.model(input_values)["last_hidden_state"]
40
+ return feats
41
+
42
+
43
+ # class CNHubertLarge(nn.Module):
44
+ # def __init__(self):
45
+ # super().__init__()
46
+ # self.model = HubertModel.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-hubert-large")
47
+ # self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-hubert-large")
48
+ # def forward(self, x):
49
+ # input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
50
+ # feats = self.model(input_values)["last_hidden_state"]
51
+ # return feats
52
+ #
53
+ # class CVec(nn.Module):
54
+ # def __init__(self):
55
+ # super().__init__()
56
+ # self.model = HubertModel.from_pretrained("/data/docker/liujing04/vc-webui-big/hubert_base")
57
+ # self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/vc-webui-big/hubert_base")
58
+ # def forward(self, x):
59
+ # input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
60
+ # feats = self.model(input_values)["last_hidden_state"]
61
+ # return feats
62
+ #
63
+ # class cnw2v2base(nn.Module):
64
+ # def __init__(self):
65
+ # super().__init__()
66
+ # self.model = Wav2Vec2Model.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-wav2vec2-base")
67
+ # self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-wav2vec2-base")
68
+ # def forward(self, x):
69
+ # input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
70
+ # feats = self.model(input_values)["last_hidden_state"]
71
+ # return feats
72
+
73
+
74
+ def get_model():
75
+ model = CNHubert()
76
+ model.eval()
77
+ return model
78
+
79
+
80
+ # def get_large_model():
81
+ # model = CNHubertLarge()
82
+ # model.eval()
83
+ # return model
84
+ #
85
+ # def get_model_cvec():
86
+ # model = CVec()
87
+ # model.eval()
88
+ # return model
89
+ #
90
+ # def get_model_cnw2v2base():
91
+ # model = cnw2v2base()
92
+ # model.eval()
93
+ # return model
94
+
95
+
96
+ def get_content(hmodel, wav_16k_tensor):
97
+ with torch.no_grad():
98
+ feats = hmodel(wav_16k_tensor)
99
+ return feats.transpose(1, 2)
100
+
101
+
102
+ if __name__ == "__main__":
103
+ model = get_model()
104
+ src_path = "/Users/Shared/原音频2.wav"
105
+ wav_16k_tensor = utils.load_wav_to_torch_and_resample(src_path, 16000)
106
+ model = model
107
+ wav_16k_tensor = wav_16k_tensor
108
+ feats = get_content(model, wav_16k_tensor)
109
+ print(feats.shape)
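
A hedged end-to-end sketch of the CNHubert wrapper (the checkpoint path and the random 16 kHz input are assumptions; real audio would come from librosa or torchaudio):

import torch
from feature_extractor import cnhubert

cnhubert.cnhubert_base_path = "pretrained_models/chinese-hubert-base"  # assumed local path
model = cnhubert.get_model()  # eval-mode CNHubert on CPU

wav_16k = torch.randn(3 * 16000)              # 3 s of 16 kHz mono audio
feats = cnhubert.get_content(model, wav_16k)  # (1, 768, T) after the final transpose
print(feats.shape)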
feature_extractor/whisper_enc.py ADDED
@@ -0,0 +1,25 @@
1
+ import torch
2
+
3
+
4
+ def get_model():
5
+ import whisper
6
+
7
+ model = whisper.load_model("small", device="cpu")
8
+
9
+ return model.encoder
10
+
11
+
12
+ def get_content(model=None, wav_16k_tensor=None):
13
+ from whisper import log_mel_spectrogram, pad_or_trim
14
+
15
+ dev = next(model.parameters()).device
16
+ mel = log_mel_spectrogram(wav_16k_tensor).to(dev)[:, :3000]
17
+ # if torch.cuda.is_available():
18
+ # mel = mel.to(torch.float16)
19
+ feature_len = mel.shape[-1] // 2
20
+ assert mel.shape[-1] < 3000, "Input audio is too long; only audio up to 30 seconds is supported"
21
+ with torch.no_grad():
22
+ feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[
23
+ :1, :feature_len, :
24
+ ].transpose(1, 2)
25
+ return feature
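
A hedged sketch of the Whisper-encoder content path (requires the openai-whisper package; the 5 s random input is an assumption and must stay under the 30 s limit enforced by the assert above):

import torch
from feature_extractor import whisper_enc

encoder = whisper_enc.get_model()  # Whisper "small" encoder, loaded on CPU
wav_16k = torch.randn(5 * 16000)   # 5 s of 16 kHz audio
feats = whisper_enc.get_content(encoder, wav_16k)
print(feats.shape)                 # (1, d_model, n_mel_frames // 2)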
inference_webui.py ADDED
@@ -0,0 +1,867 @@
1
+ import logging
2
+ import os
3
+ import re
4
+ import traceback
5
+ from time import time as ttime
6
+
7
+ import gradio as gr
8
+ import gradio.themes as themes
9
+ import librosa
10
+ import nltk
11
+ import numpy as np
12
+ import spaces
13
+ import torch
14
+ import torchaudio
15
+ from gradio.themes.utils import fonts
16
+ from huggingface_hub import snapshot_download
17
+ from transformers.models.auto.modeling_auto import AutoModelForMaskedLM
18
+ from transformers.models.auto.tokenization_auto import AutoTokenizer
19
+
20
+ from AR.models.structs import T2SRequest
21
+ from AR.models.t2s_model_flash_attn import CUDAGraphRunner
22
+ from feature_extractor import cnhubert
23
+ from module.mel_processing import spectrogram_torch
24
+ from module.models import SynthesizerTrn
25
+ from sv import SV
26
+ from text import chinese, cleaned_text_to_sequence
27
+ from text.cleaner import clean_text
28
+ from text.LangSegmenter import LangSegmenter
29
+ from tools.i18n.i18n import I18nAuto
30
+
31
+ logging.getLogger("markdown_it").setLevel(logging.ERROR)
32
+ logging.getLogger("urllib3").setLevel(logging.ERROR)
33
+ logging.getLogger("httpcore").setLevel(logging.ERROR)
34
+ logging.getLogger("httpx").setLevel(logging.ERROR)
35
+ logging.getLogger("asyncio").setLevel(logging.ERROR)
36
+ logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
37
+ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
38
+ logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
39
+ logging.getLogger("python_multipart.multipart").setLevel(logging.ERROR)
40
+ logging.getLogger("split_lang.split.splitter").setLevel(logging.ERROR)
41
+ logging.getLogger("filelock").setLevel(logging.INFO)
42
+
43
+ os.makedirs("pretrained_models", exist_ok=True)
44
+
45
+ nltk.download("averaged_perceptron_tagger_eng")
46
+
47
+ snapshot_download(
48
+ repo_id="lj1995/GPT-SoVITS",
49
+ repo_type="model",
50
+ allow_patterns="chinese*",
51
+ local_dir="pretrained_models",
52
+ )
53
+ snapshot_download(
54
+ repo_id="lj1995/GPT-SoVITS",
55
+ repo_type="model",
56
+ allow_patterns="s1v3.ckpt",
57
+ local_dir="pretrained_models",
58
+ )
59
+ snapshot_download(
60
+ repo_id="lj1995/GPT-SoVITS",
61
+ repo_type="model",
62
+ allow_patterns="sv*",
63
+ local_dir="pretrained_models",
64
+ )
65
+ snapshot_download(
66
+ repo_id="lj1995/GPT-SoVITS",
67
+ repo_type="model",
68
+ allow_patterns="v2Pro/s2Gv2ProPlus.pth",
69
+ local_dir="pretrained_models",
70
+ )
71
+
72
+ version = "v2" # os.environ.get("version","v2")
73
+ cnhubert_base_path = os.environ.get("cnhubert_base_path", "pretrained_models/chinese-hubert-base")
74
+ bert_path = os.environ.get("bert_path", "pretrained_models/chinese-roberta-wwm-ext-large")
75
+ cnhubert.cnhubert_base_path = cnhubert_base_path
76
+
77
+ punctuation = set(["!", "?", "…", ",", ".", "-", " "])
78
+
79
+
80
+ i18n = I18nAuto(language="Auto")
81
+
82
+ if torch.cuda.is_available():
83
+ device = "cuda"
84
+ is_half = True
85
+ else:
86
+ device = "cpu"
87
+ is_half = False
88
+
89
+ dict_language_v1 = {
90
+ i18n("中文"): "all_zh", # 全部按中文识别
91
+ i18n("英文"): "en", # 全部按英文识别#######不变
92
+ i18n("日文"): "all_ja", # 全部按日文识别
93
+ i18n("中英混合"): "zh", # 按中英混合识别####不变
94
+ i18n("日英混合"): "ja", # 按日英混合识别####不变
95
+ i18n("多语种混合"): "auto", # 多语种启动切分识别语种
96
+ }
97
+ dict_language_v2 = {
98
+ i18n("中文"): "all_zh", # 全部按中文识别
99
+ i18n("英文"): "en", # 全部按英文识别#######不变
100
+ i18n("日文"): "all_ja", # 全部按日文识别
101
+ i18n("粤语"): "all_yue", # 全部按中文识别
102
+ i18n("韩文"): "all_ko", # 全部按韩文识别
103
+ i18n("中英混合"): "zh", # 按中英混合识别####不变
104
+ i18n("日英混合"): "ja", # 按日英混合识别####不变
105
+ i18n("粤英混合"): "yue", # 按粤英混合识别####不变
106
+ i18n("韩英混合"): "ko", # 按韩英混合识别####不变
107
+ i18n("多语种混合"): "auto", # 多语种启动切分识别语种
108
+ i18n("多语种混合(粤语)"): "auto_yue", # 多语种启动切分识别语种
109
+ }
110
+ dict_language = dict_language_v1 if version == "v1" else dict_language_v2
111
+
112
+ tokenizer = AutoTokenizer.from_pretrained(bert_path)
113
+ bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
114
+ if is_half is True:
115
+ bert_model = bert_model.half().to(device)
116
+ else:
117
+ bert_model = bert_model.to(device)
118
+
119
+
120
+ def get_bert_feature(text, word2ph):
121
+ with torch.no_grad():
122
+ inputs = tokenizer(text, return_tensors="pt")
123
+ for i in inputs:
124
+ inputs[i] = inputs[i].to(device)
125
+ res = bert_model(**inputs, output_hidden_states=True)
126
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
127
+ assert len(word2ph) == len(text)
128
+ phone_level_feature = []
129
+ for i in range(len(word2ph)):
130
+ repeat_feature = res[i].repeat(word2ph[i], 1)
131
+ phone_level_feature.append(repeat_feature)
132
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
133
+ return phone_level_feature.T
134
+
135
+
136
+ class DictToAttrRecursive(dict):
137
+ def __init__(self, input_dict):
138
+ super().__init__(input_dict)
139
+ for key, value in input_dict.items():
140
+ if isinstance(value, dict):
141
+ value = DictToAttrRecursive(value)
142
+ self[key] = value
143
+ setattr(self, key, value)
144
+
145
+ def __getattr__(self, item):
146
+ try:
147
+ return self[item]
148
+ except KeyError:
149
+ raise AttributeError(f"Attribute {item} not found")
150
+
151
+ def __setattr__(self, key, value):
152
+ if isinstance(value, dict):
153
+ value = DictToAttrRecursive(value)
154
+ super(DictToAttrRecursive, self).__setitem__(key, value)
155
+ super().__setattr__(key, value)
156
+
157
+ def __delattr__(self, item):
158
+ try:
159
+ del self[item]
160
+ except KeyError:
161
+ raise AttributeError(f"Attribute {item} not found")
162
+
163
+
164
+ ssl_model = cnhubert.get_model()
165
+ if is_half is True:
166
+ ssl_model = ssl_model.half().to(device)
167
+ else:
168
+ ssl_model = ssl_model.to(device)
169
+
170
+
171
+ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
172
+ global vq_model, hps, version, dict_language
173
+ dict_s2 = torch.load(sovits_path, map_location="cpu")
174
+ hps = dict_s2["config"]
175
+ hps = DictToAttrRecursive(hps)
176
+ hps.model.semantic_frame_rate = "25hz"
177
+ if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
178
+ hps.model.version = "v1"
179
+ else:
180
+ hps.model.version = "v2"
181
+ version = hps.model.version
182
+ # print("sovits版本:",hps.model.version)
183
+ vq_model = SynthesizerTrn(
184
+ hps.data.filter_length // 2 + 1,
185
+ hps.train.segment_size // hps.data.hop_length,
186
+ n_speakers=hps.data.n_speakers,
187
+ **hps.model,
188
+ )
189
+ if "pretrained" not in sovits_path:
190
+ del vq_model.enc_q
191
+ if is_half is True:
192
+ vq_model = vq_model.half().to(device)
193
+ else:
194
+ vq_model = vq_model.to(device)
195
+ vq_model.eval()
196
+ print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
197
+ dict_language = dict_language_v1 if version == "v1" else dict_language_v2
198
+ if prompt_language is not None and text_language is not None:
199
+ if prompt_language in list(dict_language.keys()):
200
+ prompt_text_update, prompt_language_update = (
201
+ {"__type__": "update"},
202
+ {"__type__": "update", "value": prompt_language},
203
+ )
204
+ else:
205
+ prompt_text_update = {"__type__": "update", "value": ""}
206
+ prompt_language_update = {"__type__": "update", "value": i18n("中文")}
207
+ if text_language in list(dict_language.keys()):
208
+ text_update, text_language_update = {"__type__": "update"}, {"__type__": "update", "value": text_language}
209
+ else:
210
+ text_update = {"__type__": "update", "value": ""}
211
+ text_language_update = {"__type__": "update", "value": i18n("中文")}
212
+ return (
213
+ {"__type__": "update", "choices": list(dict_language.keys())},
214
+ {"__type__": "update", "choices": list(dict_language.keys())},
215
+ prompt_text_update,
216
+ prompt_language_update,
217
+ text_update,
218
+ text_language_update,
219
+ )
220
+
221
+
222
+ change_sovits_weights("pretrained_models/v2Pro/s2Gv2ProPlus.pth")
223
+
224
+
225
+ def change_gpt_weights(gpt_path):
226
+ global t2s_model, config
227
+ dict_s1 = torch.load(gpt_path, map_location="cpu")
228
+ config = dict_s1["config"]
229
+ t2s_model = CUDAGraphRunner(
230
+ CUDAGraphRunner.load_decoder(gpt_path), torch.device(device), torch.float16 if is_half else torch.float32
231
+ )
232
+ total = sum(p.numel() for p in t2s_model.decoder_model.parameters())
233
+ print("Number of parameter: %.2fM" % (total / 1e6))
234
+
235
+
236
+ change_gpt_weights("pretrained_models/s1v3.ckpt")
237
+
238
+
239
+ sv_cn_model = SV(device, is_half)
240
+
241
+ resample_transform_dict = {}
242
+
243
+
244
+ def resample(audio_tensor, sr0, sr1, device):
245
+ global resample_transform_dict
246
+ key = "%s-%s-%s" % (sr0, sr1, str(device))
247
+ if key not in resample_transform_dict:
248
+ resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
249
+ return resample_transform_dict[key](audio_tensor)
250
+
251
+
252
+ def get_spepc(hps, filename, dtype, device, is_v2pro=False):
253
+ sr1 = int(hps.data.sampling_rate)
254
+ audio, sr0 = torchaudio.load(filename)
255
+ if sr0 != sr1:
256
+ audio = audio.to(device)
257
+ if audio.shape[0] == 2:
258
+ audio = audio.mean(0).unsqueeze(0)
259
+ audio = resample(audio, sr0, sr1, device)
260
+ else:
261
+ audio = audio.to(device)
262
+ if audio.shape[0] == 2:
263
+ audio = audio.mean(0).unsqueeze(0)
264
+
265
+ maxx = audio.abs().max()
266
+ if maxx > 1:
267
+ audio /= min(2, maxx)
268
+ spec = spectrogram_torch(
269
+ audio,
270
+ hps.data.filter_length,
271
+ hps.data.sampling_rate,
272
+ hps.data.hop_length,
273
+ hps.data.win_length,
274
+ center=False,
275
+ )
276
+ spec = spec.to(dtype)
277
+ if is_v2pro is True:
278
+ audio = resample(audio, sr1, 16000, device).to(dtype)
279
+ return spec, audio
280
+
281
+
282
+ def clean_text_inf(text, language, version):
283
+ language = language.replace("all_", "")
284
+ phones, word2ph, norm_text = clean_text(text, language, version)
285
+ phones = cleaned_text_to_sequence(phones, version)
286
+ return phones, word2ph, norm_text
287
+
288
+
289
+ dtype = torch.float16 if is_half is True else torch.float32
290
+
291
+
292
+ def get_bert_inf(phones, word2ph, norm_text, language):
293
+ language = language.replace("all_", "")
294
+ if language == "zh":
295
+ bert = get_bert_feature(norm_text, word2ph).to(device) # .to(dtype)
296
+ else:
297
+ bert = torch.zeros(
298
+ (1024, len(phones)),
299
+ dtype=torch.float16 if is_half is True else torch.float32,
300
+ ).to(device)
301
+
302
+ return bert
303
+
304
+
305
+ splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…"}
306
+
307
+
308
+ def get_first(text):
309
+ pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
310
+ text = re.split(pattern, text)[0].strip()
311
+ return text
312
+
313
+
314
+ def get_phones_and_bert(text, language, version, final=False):
315
+ if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
316
+ formattext = text
317
+ while " " in formattext:
318
+ formattext = formattext.replace(" ", " ")
319
+ if language == "all_zh":
320
+ if re.search(r"[A-Za-z]", formattext):
321
+ formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
322
+ formattext = chinese.mix_text_normalize(formattext)
323
+ return get_phones_and_bert(formattext, "zh", version)
324
+ else:
325
+ phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
326
+ bert = get_bert_feature(norm_text, word2ph).to(device)
327
+ elif language == "all_yue" and re.search(r"[A-Za-z]", formattext):
328
+ formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
329
+ formattext = chinese.mix_text_normalize(formattext)
330
+ return get_phones_and_bert(formattext, "yue", version)
331
+ else:
332
+ phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
333
+ bert = torch.zeros(
334
+ (1024, len(phones)),
335
+ dtype=torch.float16 if is_half is True else torch.float32,
336
+ ).to(device)
337
+ elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
338
+ textlist = []
339
+ langlist = []
340
+ if language == "auto":
341
+ for tmp in LangSegmenter.getTexts(text):
342
+ langlist.append(tmp["lang"])
343
+ textlist.append(tmp["text"])
344
+ elif language == "auto_yue":
345
+ for tmp in LangSegmenter.getTexts(text):
346
+ if tmp["lang"] == "zh":
347
+ tmp["lang"] = "yue"
348
+ langlist.append(tmp["lang"])
349
+ textlist.append(tmp["text"])
350
+ else:
351
+ for tmp in LangSegmenter.getTexts(text):
352
+ if tmp["lang"] == "en":
353
+ langlist.append(tmp["lang"])
354
+ else:
355
+ # Chinese, Japanese and Korean Han characters cannot be told apart, so follow the user-selected language
356
+ langlist.append(language)
357
+ textlist.append(tmp["text"])
358
+ print(textlist)
359
+ print(langlist)
360
+ phones_list = []
361
+ bert_list = []
362
+ norm_text_list = []
363
+ for i in range(len(textlist)):
364
+ lang = langlist[i]
365
+ phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
366
+ bert = get_bert_inf(phones, word2ph, norm_text, lang)
367
+ phones_list.append(phones)
368
+ norm_text_list.append(norm_text)
369
+ bert_list.append(bert)
370
+ bert = torch.cat(bert_list, dim=1)
371
+ phones = sum(phones_list, [])
372
+ norm_text = "".join(norm_text_list)
373
+
374
+ if not final and len(phones) < 6:
375
+ return get_phones_and_bert("." + text, language, version, final=True)
376
+
377
+ return phones, bert.to(dtype), norm_text
378
+
379
+
380
+ def merge_short_text_in_array(texts, threshold):
381
+ if (len(texts)) < 2:
382
+ return texts
383
+ result = []
384
+ text = ""
385
+ for ele in texts:
386
+ text += ele
387
+ if len(text) >= threshold:
388
+ result.append(text)
389
+ text = ""
390
+ if len(text) > 0:
391
+ if len(result) == 0:
392
+ result.append(text)
393
+ else:
394
+ result[len(result) - 1] += text
395
+ return result
396
+
397
+
398
+ ## ref_wav_path + prompt_text + prompt_language + text (single segment) + text_language + top_k + top_p + temperature
399
+ # cache_tokens = {}  # cache cleanup not implemented yet
400
+ cache = {}
401
+
402
+
403
+ @spaces.GPU
404
+ def get_tts_wav(
405
+ ref_wav_path,
406
+ prompt_text,
407
+ prompt_language,
408
+ text,
409
+ text_language,
410
+ how_to_cut=i18n("不切"),
411
+ top_k=20,
412
+ top_p=0.6,
413
+ temperature=0.6,
414
+ ref_free=False,
415
+ speed=1,
416
+ if_freeze=False,
417
+ inp_refs=123,
418
+ ):
419
+ global cache
420
+ if ref_wav_path:
421
+ pass
422
+ else:
423
+ gr.Warning(i18n("请上传参考音频"))
424
+ if text:
425
+ pass
426
+ else:
427
+ gr.Warning(i18n("请填入推理文本"))
428
+ t = []
429
+ if prompt_text is None or len(prompt_text) == 0:
430
+ ref_free = True
431
+ t0 = ttime()
432
+ prompt_language = dict_language[prompt_language]
433
+ text_language = dict_language[text_language]
434
+
435
+ if not ref_free:
436
+ prompt_text = prompt_text.strip("\n")
437
+ if prompt_text[-1] not in splits:
438
+ prompt_text += "。" if prompt_language != "en" else "."
439
+ print(i18n("实际输入的参考文本:"), prompt_text)
440
+ text = text.strip("\n")
441
+ if text[0] not in splits and len(get_first(text)) < 4:
442
+ text = "。" + text if text_language != "en" else "." + text
443
+
444
+ print(i18n("实际输入的目标文本:"), text)
445
+ zero_wav = np.zeros(
446
+ int(hps.data.sampling_rate * 0.3),
447
+ dtype=np.float16 if is_half is True else np.float32,
448
+ )
449
+ if not ref_free:
450
+ with torch.no_grad():
451
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000)
452
+ if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
453
+ gr.Warning(i18n("参考音频在3~10秒范围外,请更换!"))
454
+ raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
455
+ wav16k = torch.from_numpy(wav16k)
456
+ zero_wav_torch = torch.from_numpy(zero_wav)
457
+ if is_half is True:
458
+ wav16k = wav16k.half().to(device)
459
+ zero_wav_torch = zero_wav_torch.half().to(device)
460
+ else:
461
+ wav16k = wav16k.to(device)
462
+ zero_wav_torch = zero_wav_torch.to(device)
463
+ wav16k = torch.cat([wav16k, zero_wav_torch])
464
+ ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
465
+ codes = vq_model.extract_latent(ssl_content)
466
+ prompt_semantic = codes[0, 0]
467
+ prompt = prompt_semantic.unsqueeze(0).to(device)
468
+
469
+ t1 = ttime()
470
+ t.append(t1 - t0)
471
+
472
+ if how_to_cut == i18n("凑四句一切"):
473
+ text = cut1(text)
474
+ elif how_to_cut == i18n("凑50字一切"):
475
+ text = cut2(text)
476
+ elif how_to_cut == i18n("按中文句号。切"):
477
+ text = cut3(text)
478
+ elif how_to_cut == i18n("按英文句号.切"):
479
+ text = cut4(text)
480
+ elif how_to_cut == i18n("按标点符号切"):
481
+ text = cut5(text)
482
+ while "\n\n" in text:
483
+ text = text.replace("\n\n", "\n")
484
+ print(i18n("实际输入的目标文本(切句后):"), text)
485
+ texts = text.split("\n")
486
+ texts = process_text(texts)
487
+ texts = merge_short_text_in_array(texts, 5)
488
+ audio_opt = []
489
+ if not ref_free:
490
+ phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
491
+
492
+ infer_speed: list[float] = []
493
+
494
+ for i_text, text in enumerate(texts):
495
+ # skip empty lines in the target text to avoid errors
496
+ if len(text.strip()) == 0:
497
+ continue
498
+ if text[-1] not in splits:
499
+ text += "。" if text_language != "en" else "."
500
+ print(i18n("实际输入的目标文本(每句):"), text)
501
+ phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
502
+ print(i18n("前端处理后的文本(每句):"), norm_text2)
503
+ if not ref_free:
504
+ bert = torch.cat([bert1, bert2], 1)
505
+ all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
506
+ else:
507
+ bert = bert2
508
+ all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
509
+
510
+ bert = bert.to(device).unsqueeze(0)
511
+ all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
512
+
513
+ t2 = ttime()
514
+ # cache_key="%s-%s-%s-%s-%s-%s-%s-%s"%(ref_wav_path,prompt_text,prompt_language,text,text_language,top_k,top_p,temperature)
515
+ # print(cache.keys(),if_freeze)
516
+ if i_text in cache and if_freeze is True:
517
+ pred_semantic = cache[i_text]
518
+ else:
519
+ with torch.no_grad():
520
+ t2s_request = T2SRequest(
521
+ [all_phoneme_ids.squeeze(0)],
522
+ all_phoneme_len,
523
+ all_phoneme_ids.new_zeros((1, 0)) if ref_free else prompt,
524
+ [bert.squeeze(0)],
525
+ valid_length=1,
526
+ top_k=top_k,
527
+ top_p=top_p,
528
+ temperature=temperature,
529
+ early_stop_num=1500,
530
+ use_cuda_graph=True,
531
+ # debug=True,
532
+ )
533
+ t2s_result = t2s_model.generate(t2s_request)
534
+
535
+ if t2s_result.exception is not None:
536
+ print(t2s_result.traceback)
537
+ raise t2s_result.exception
538
+
539
+ infer_speed.append(t2s_result.infer_speed)
540
+ pred_semantic = t2s_result.result
541
+ assert pred_semantic
542
+ cache[i_text] = pred_semantic
543
+ t3 = ttime()
544
+ refers = []
545
+ sv_emb = []
546
+ if inp_refs:
547
+ for path in inp_refs:
548
+ try:
549
+ refer, audio_tensor = get_spepc(hps, path.name, dtype, device, is_v2pro=True)
550
+ refers.append(refer)
551
+ sv_emb.append(sv_cn_model.compute_embedding3(audio_tensor))
552
+ except:
553
+ traceback.print_exc()
554
+ if len(refers) == 0:
555
+ refers, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device, is_v2pro=True)
556
+ refers = [refers]
557
+ sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)]
558
+ audio = (
559
+ vq_model.decode(
560
+ pred_semantic[0].unsqueeze(0).unsqueeze(0),
561
+ torch.LongTensor(phones2).to(device).unsqueeze(0),
562
+ refers,
563
+ speed=speed,
564
+ sv_emb=sv_emb,
565
+ )
566
+ .detach()
567
+ .cpu()
568
+ .numpy()[0][0]
569
+ )
570
+ max_audio = np.abs(audio).max()  # simple guard against 16-bit clipping
571
+ if max_audio > 1:
572
+ audio /= max_audio
573
+ audio_opt.append(audio)
574
+ audio_opt.append(zero_wav)
575
+ t4 = ttime()
576
+ t.extend([t2 - t1, t3 - t2, t4 - t3])
577
+ t1 = ttime()
578
+ print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
579
+ gr.Info(f"{sum(infer_speed) / len(infer_speed):.2f} Token/s", title="Infer Speed")
580
+ gr.Info("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])), title="Time Stamps")
581
+ yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
582
+
583
+
584
+ def split(todo_text):
585
+ todo_text = todo_text.replace("……", "。").replace("——", ",")
586
+ if todo_text[-1] not in splits:
587
+ todo_text += "。"
588
+ i_split_head = i_split_tail = 0
589
+ len_text = len(todo_text)
590
+ todo_texts = []
591
+ while 1:
592
+ if i_split_head >= len_text:
593
+ break  # the text always ends with punctuation, so just break; the last segment was already appended in the previous iteration
594
+ if todo_text[i_split_head] in splits:
595
+ i_split_head += 1
596
+ todo_texts.append(todo_text[i_split_tail:i_split_head])
597
+ i_split_tail = i_split_head
598
+ else:
599
+ i_split_head += 1
600
+ return todo_texts
601
+
602
+
603
+ def cut1(inp):
604
+ inp = inp.strip("\n")
605
+ inps = split(inp)
606
+ split_idx = list(range(0, len(inps), 4))
607
+ split_idx[-1] = None
608
+ if len(split_idx) > 1:
609
+ opts = []
610
+ for idx in range(len(split_idx) - 1):
611
+ opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
612
+ else:
613
+ opts = [inp]
614
+ opts = [item for item in opts if not set(item).issubset(punctuation)]
615
+ return "\n".join(opts)
616
+
617
+
618
+ def cut2(inp):
619
+ inp = inp.strip("\n")
620
+ inps = split(inp)
621
+ if len(inps) < 2:
622
+ return inp
623
+ opts = []
624
+ summ = 0
625
+ tmp_str = ""
626
+ for i in range(len(inps)):
627
+ summ += len(inps[i])
628
+ tmp_str += inps[i]
629
+ if summ > 50:
630
+ summ = 0
631
+ opts.append(tmp_str)
632
+ tmp_str = ""
633
+ if tmp_str != "":
634
+ opts.append(tmp_str)
635
+ # print(opts)
636
+ if len(opts) > 1 and len(opts[-1]) < 50:  ## if the last chunk is too short, merge it into the previous one
637
+ opts[-2] = opts[-2] + opts[-1]
638
+ opts = opts[:-1]
639
+ opts = [item for item in opts if not set(item).issubset(punctuation)]
640
+ return "\n".join(opts)
641
+
642
+
643
+ def cut3(inp):
644
+ inp = inp.strip("\n")
645
+ opts = ["%s" % item for item in inp.strip("。").split("。")]
646
+ opts = [item for item in opts if not set(item).issubset(punctuation)]
647
+ return "\n".join(opts)
648
+
649
+
650
+ def cut4(inp):
651
+ inp = inp.strip("\n")
652
+ opts = ["%s" % item for item in inp.strip(".").split(".")]
653
+ opts = [item for item in opts if not set(item).issubset(punctuation)]
654
+ return "\n".join(opts)
655
+
656
+
657
+ # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
658
+ def cut5(inp):
659
+ inp = inp.strip("\n")
660
+ punds = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}
661
+ mergeitems = []
662
+ items = []
663
+
664
+ for i, char in enumerate(inp):
665
+ if char in punds:
666
+ if char == "." and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
667
+ items.append(char)
668
+ else:
669
+ items.append(char)
670
+ mergeitems.append("".join(items))
671
+ items = []
672
+ else:
673
+ items.append(char)
674
+
675
+ if items:
676
+ mergeitems.append("".join(items))
677
+
678
+ opt = [item for item in mergeitems if not set(item).issubset(punds)]
679
+ return "\n".join(opt)
680
+
681
+
682
+ def custom_sort_key(s):
683
+ # use a regular expression to split the string into numeric and non-numeric parts
684
+ parts = re.split(r"(\d+)", s)
685
+ # convert the numeric parts to integers and keep the non-numeric parts unchanged
686
+ parts = [int(part) if part.isdigit() else part for part in parts]
687
+ return parts
688
+
689
+
690
+ def process_text(texts):
691
+ _text = []
692
+ if all(text in [None, " ", "\n", ""] for text in texts):
693
+ raise ValueError(i18n("请输入有效文本"))
694
+ for text in texts:
695
+ if text in [None, " ", ""]:
696
+ pass
697
+ else:
698
+ _text.append(text)
699
+ return _text
700
+
701
+
702
+ def html_center(text, label="p"):
703
+ return f"""<div style="text-align: center; margin: 100; padding: 50;">
704
+ <{label} style="margin: 0; padding: 0;">{text}</{label}>
705
+ </div>"""
706
+
707
+
708
+ def html_left(text, label="p"):
709
+ return f"""<div style="text-align: left; margin: 0; padding: 0;">
710
+ <{label} style="margin: 0; padding: 0;">{text}</{label}>
711
+ </div>"""
712
+
713
+
714
+ theme = themes.Soft(
715
+ font=(
716
+ "-apple-system",
717
+ fonts.GoogleFont("Inter"),
718
+ fonts.GoogleFont("Quicksand"),
719
+ "ui-sans-serif",
720
+ "sans-serif",
721
+ )
722
+ )
723
+ theme.block_border_width = "1px"
724
+
725
+ with gr.Blocks(
726
+ title="GPT-SoVITS WebUI",
727
+ theme=theme,
728
+ analytics_enabled=False,
729
+ ) as app:
730
+ gr.Markdown(
731
+ value="""# GPT-SoVITS-ProPlus Zero-shot TTS Demo
732
+ ## https://github.com/RVC-Boss/GPT-SoVITS
733
+ Input 3 to 10s reference audio to guide the timbre, speed and emotion of the voice, and generate the speech you want by inputting the inference text. <br>
734
+ 输入3至10秒的参考音频来引导待合成语音的音色、语速和情感,然后输入待合成目标文本,生成目标语音. <br>
735
+ Cross-lingual Support: Inference in languages different from the training dataset, currently supporting English, Japanese, Korean and Cantonese.<br>
736
+ 目前支持中日英韩粤跨语种合成。<br>
737
+ This demo is open source under the MIT license. The author does not have any control over it. Users who use the software and distribute the sounds exported by the software are solely responsible. If you do not agree with this clause, you cannot use or reference any codes and files within this demo. <br>
738
+ 本demo以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. 如不认可该条款, 则不能使用或引用该demo内的任何代码和文件.
739
+ """
740
+ )
741
+ gr.Markdown(html_center(i18n("*请上传并填写参考信息"), "h3"))
742
+ with gr.Row(equal_height=True):
743
+ inp_ref = gr.Audio(label=i18n("请上传3~10秒内参考音频,超过会报错!"), type="filepath")
744
+ with gr.Column():
745
+ ref_text_free = gr.Checkbox(
746
+ label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"),
747
+ value=False,
748
+ interactive=True,
749
+ show_label=True,
750
+ )
751
+ prompt_text = gr.Textbox(
752
+ label=i18n("参考音频的文本"),
753
+ value="",
754
+ lines=3,
755
+ max_lines=3,
756
+ info=i18n(
757
+ "使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开。<br>开启后无视填写的参考文本。"
758
+ ),
759
+ )
760
+ prompt_language = gr.Dropdown(
761
+ label=i18n("参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
762
+ )
763
+ inp_refs = gr.File(
764
+ label=i18n(
765
+ "可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。"
766
+ ),
767
+ file_count="multiple",
768
+ )
769
+ gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"), "h3"))
770
+ with gr.Row(equal_height=True):
771
+ with gr.Column():
772
+ text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=26, max_lines=26)
773
+ with gr.Column():
774
+ text_language = gr.Dropdown(
775
+ label=i18n("需要合成的语种") + i18n(".限制范围越小判别效果越好。"),
776
+ choices=list(dict_language.keys()),
777
+ value=i18n("中文"),
778
+ )
779
+ how_to_cut = gr.Dropdown(
780
+ label=i18n("怎么切"),
781
+ choices=[
782
+ i18n("不切"),
783
+ i18n("凑四句一切"),
784
+ i18n("凑50字一切"),
785
+ i18n("按中文句号。切"),
786
+ i18n("按英文句号.切"),
787
+ i18n("按标点符号切"),
788
+ ],
789
+ value=i18n("凑四句一切"),
790
+ interactive=True,
791
+ )
792
+ gr.Markdown(value=html_center(i18n("语速调整,高为更快")))
793
+ if_freeze = gr.Checkbox(
794
+ label=i18n("是否直接对上次合成结果调整语速和音色。防止随机性。"),
795
+ value=False,
796
+ interactive=True,
797
+ show_label=True,
798
+ )
799
+ speed = gr.Slider(minimum=0.6, maximum=1.65, step=0.05, label=i18n("语速"), value=1, interactive=True)
800
+ gr.Markdown(html_center(i18n("GPT采样参数(无参考文本时不要太低。不懂就用默认):")))
801
+ top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=15, interactive=True)
802
+ top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
803
+ temperature = gr.Slider(
804
+ minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True
805
+ )
806
+ with gr.Row(equal_height=True):
807
+ inference_button = gr.Button(i18n("合成语音"), variant="primary", size="lg")
808
+ output = gr.Audio(label=i18n("输出的语音"))
809
+
810
+ inference_button.click(
811
+ get_tts_wav,
812
+ [
813
+ inp_ref,
814
+ prompt_text,
815
+ prompt_language,
816
+ text,
817
+ text_language,
818
+ how_to_cut,
819
+ top_k,
820
+ top_p,
821
+ temperature,
822
+ ref_text_free,
823
+ speed,
824
+ if_freeze,
825
+ inp_refs,
826
+ ],
827
+ [output],
828
+ )
829
+
830
+ if __name__ == "__main__":
831
+ import tempfile
832
+ import wave
833
+
834
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_file:
835
+ file_name = temp_file.name
836
+ with wave.open(temp_file, "w") as wav_file:
837
+ channels = 1
838
+ sample_width = 2
839
+ sample_rate = 44100
840
+ duration = 5
841
+ frequency = 440.0
842
+
843
+ t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
844
+ sine_wave = np.sin(2 * np.pi * frequency * t) # Sine Wave
845
+ int_wave = (sine_wave * 32767).astype(np.int16)
846
+
847
+ wav_file.setnchannels(channels) # pylint: disable=no-member
848
+ wav_file.setsampwidth(sample_width) # pylint: disable=no-member
849
+ wav_file.setframerate(sample_rate) # pylint: disable=no-member
850
+ wav_file.writeframes(int_wave.tobytes()) # pylint: disable=no-member
851
+
852
+ gen = get_tts_wav(
853
+ ref_wav_path=file_name,
854
+ prompt_text="",
855
+ prompt_language=i18n("中文"),
856
+ text="犯大吴疆土者,盛必击而破之,犯大吴疆土者,盛必击而破之,犯大吴疆土者,盛必击而破之,犯大吴疆土者,盛必击而破之",
857
+ text_language=i18n("中文"),
858
+ inp_refs=[],
859
+ )
860
+ next(gen)
861
+
862
+ app.queue().launch(
863
+ server_name="0.0.0.0",
864
+ inbrowser=True,
865
+ show_api=False,
866
+ allowed_paths=["/"],
867
+ )
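
For reference (not part of the diff): the cut helpers defined above split long target text into newline-separated chunks before synthesis, and merge_short_text_in_array then merges chunks that are too short. A small illustrative sketch, assuming the functions are in scope (importing the webui module itself would also trigger the model downloads above):

text = "第一句。第二句!第三句?第四句。第五句。"
print(cut5(text))  # one chunk per punctuation mark
print(cut3(text))  # split on Chinese full stops only
chunks = merge_short_text_in_array(cut5(text).split("\n"), 5)
print(chunks)      # greedy merge until each chunk reaches at least 5 characters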
module/__init__.py ADDED
File without changes
module/attentions.py ADDED
@@ -0,0 +1,709 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from module import commons
7
+ from module.modules import LayerNorm
8
+
9
+
10
+ class Encoder(nn.Module):
11
+ def __init__(
12
+ self,
13
+ hidden_channels,
14
+ filter_channels,
15
+ n_heads,
16
+ n_layers,
17
+ kernel_size=1,
18
+ p_dropout=0.0,
19
+ window_size=4,
20
+ isflow=False,
21
+ **kwargs
22
+ ):
23
+ super().__init__()
24
+ self.hidden_channels = hidden_channels
25
+ self.filter_channels = filter_channels
26
+ self.n_heads = n_heads
27
+ self.n_layers = n_layers
28
+ self.kernel_size = kernel_size
29
+ self.p_dropout = p_dropout
30
+ self.window_size = window_size
31
+
32
+ self.drop = nn.Dropout(p_dropout)
33
+ self.attn_layers = nn.ModuleList()
34
+ self.norm_layers_1 = nn.ModuleList()
35
+ self.ffn_layers = nn.ModuleList()
36
+ self.norm_layers_2 = nn.ModuleList()
37
+ for i in range(self.n_layers):
38
+ self.attn_layers.append(
39
+ MultiHeadAttention(
40
+ hidden_channels,
41
+ hidden_channels,
42
+ n_heads,
43
+ p_dropout=p_dropout,
44
+ window_size=window_size,
45
+ )
46
+ )
47
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
48
+ self.ffn_layers.append(
49
+ FFN(
50
+ hidden_channels,
51
+ hidden_channels,
52
+ filter_channels,
53
+ kernel_size,
54
+ p_dropout=p_dropout,
55
+ )
56
+ )
57
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
58
+ if isflow:
59
+ cond_layer = torch.nn.Conv1d(
60
+ kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
61
+ )
62
+ self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
63
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
64
+ self.gin_channels = kwargs["gin_channels"]
65
+
66
+ def forward(self, x, x_mask, g=None):
67
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
68
+ x = x * x_mask
69
+ if g is not None:
70
+ g = self.cond_layer(g)
71
+
72
+ for i in range(self.n_layers):
73
+ if g is not None:
74
+ x = self.cond_pre(x)
75
+ cond_offset = i * 2 * self.hidden_channels
76
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
77
+ x = commons.fused_add_tanh_sigmoid_multiply(
78
+ x, g_l, torch.IntTensor([self.hidden_channels])
79
+ )
80
+ y = self.attn_layers[i](x, x, attn_mask)
81
+ y = self.drop(y)
82
+ x = self.norm_layers_1[i](x + y)
83
+
84
+ y = self.ffn_layers[i](x, x_mask)
85
+ y = self.drop(y)
86
+ x = self.norm_layers_2[i](x + y)
87
+ x = x * x_mask
88
+ return x
89
+
90
+
91
+ class Decoder(nn.Module):
92
+ def __init__(
93
+ self,
94
+ hidden_channels,
95
+ filter_channels,
96
+ n_heads,
97
+ n_layers,
98
+ kernel_size=1,
99
+ p_dropout=0.0,
100
+ proximal_bias=False,
101
+ proximal_init=True,
102
+ **kwargs
103
+ ):
104
+ super().__init__()
105
+ self.hidden_channels = hidden_channels
106
+ self.filter_channels = filter_channels
107
+ self.n_heads = n_heads
108
+ self.n_layers = n_layers
109
+ self.kernel_size = kernel_size
110
+ self.p_dropout = p_dropout
111
+ self.proximal_bias = proximal_bias
112
+ self.proximal_init = proximal_init
113
+
114
+ self.drop = nn.Dropout(p_dropout)
115
+ self.self_attn_layers = nn.ModuleList()
116
+ self.norm_layers_0 = nn.ModuleList()
117
+ self.encdec_attn_layers = nn.ModuleList()
118
+ self.norm_layers_1 = nn.ModuleList()
119
+ self.ffn_layers = nn.ModuleList()
120
+ self.norm_layers_2 = nn.ModuleList()
121
+ for i in range(self.n_layers):
122
+ self.self_attn_layers.append(
123
+ MultiHeadAttention(
124
+ hidden_channels,
125
+ hidden_channels,
126
+ n_heads,
127
+ p_dropout=p_dropout,
128
+ proximal_bias=proximal_bias,
129
+ proximal_init=proximal_init,
130
+ )
131
+ )
132
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
133
+ self.encdec_attn_layers.append(
134
+ MultiHeadAttention(
135
+ hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
136
+ )
137
+ )
138
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
139
+ self.ffn_layers.append(
140
+ FFN(
141
+ hidden_channels,
142
+ hidden_channels,
143
+ filter_channels,
144
+ kernel_size,
145
+ p_dropout=p_dropout,
146
+ causal=True,
147
+ )
148
+ )
149
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
150
+
151
+ def forward(self, x, x_mask, h, h_mask):
152
+ """
153
+ x: decoder input
154
+ h: encoder output
155
+ """
156
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
157
+ device=x.device, dtype=x.dtype
158
+ )
159
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
160
+ x = x * x_mask
161
+ for i in range(self.n_layers):
162
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
163
+ y = self.drop(y)
164
+ x = self.norm_layers_0[i](x + y)
165
+
166
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
167
+ y = self.drop(y)
168
+ x = self.norm_layers_1[i](x + y)
169
+
170
+ y = self.ffn_layers[i](x, x_mask)
171
+ y = self.drop(y)
172
+ x = self.norm_layers_2[i](x + y)
173
+ x = x * x_mask
174
+ return x
175
+
176
+
177
+ class MultiHeadAttention(nn.Module):
178
+ def __init__(
179
+ self,
180
+ channels,
181
+ out_channels,
182
+ n_heads,
183
+ p_dropout=0.0,
184
+ window_size=None,
185
+ heads_share=True,
186
+ block_length=None,
187
+ proximal_bias=False,
188
+ proximal_init=False,
189
+ ):
190
+ super().__init__()
191
+ assert channels % n_heads == 0
192
+
193
+ self.channels = channels
194
+ self.out_channels = out_channels
195
+ self.n_heads = n_heads
196
+ self.p_dropout = p_dropout
197
+ self.window_size = window_size
198
+ self.heads_share = heads_share
199
+ self.block_length = block_length
200
+ self.proximal_bias = proximal_bias
201
+ self.proximal_init = proximal_init
202
+ self.attn = None
203
+
204
+ self.k_channels = channels // n_heads
205
+ self.conv_q = nn.Conv1d(channels, channels, 1)
206
+ self.conv_k = nn.Conv1d(channels, channels, 1)
207
+ self.conv_v = nn.Conv1d(channels, channels, 1)
208
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
209
+ self.drop = nn.Dropout(p_dropout)
210
+
211
+ if window_size is not None:
212
+ n_heads_rel = 1 if heads_share else n_heads
213
+ rel_stddev = self.k_channels**-0.5
214
+ self.emb_rel_k = nn.Parameter(
215
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
216
+ * rel_stddev
217
+ )
218
+ self.emb_rel_v = nn.Parameter(
219
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
220
+ * rel_stddev
221
+ )
222
+
223
+ nn.init.xavier_uniform_(self.conv_q.weight)
224
+ nn.init.xavier_uniform_(self.conv_k.weight)
225
+ nn.init.xavier_uniform_(self.conv_v.weight)
226
+ if proximal_init:
227
+ with torch.no_grad():
228
+ self.conv_k.weight.copy_(self.conv_q.weight)
229
+ self.conv_k.bias.copy_(self.conv_q.bias)
230
+
231
+ def forward(self, x, c, attn_mask=None):
232
+ q = self.conv_q(x)
233
+ k = self.conv_k(c)
234
+ v = self.conv_v(c)
235
+
236
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
237
+
238
+ x = self.conv_o(x)
239
+ return x
240
+
241
+ def attention(self, query, key, value, mask=None):
242
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
243
+ b, d, t_s, t_t = (*key.size(), query.size(2))
244
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
245
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
246
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
247
+
248
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
249
+ if self.window_size is not None:
250
+ assert (
251
+ t_s == t_t
252
+ ), "Relative attention is only available for self-attention."
253
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
254
+ rel_logits = self._matmul_with_relative_keys(
255
+ query / math.sqrt(self.k_channels), key_relative_embeddings
256
+ )
257
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
258
+ scores = scores + scores_local
259
+ if self.proximal_bias:
260
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
261
+ scores = scores + self._attention_bias_proximal(t_s).to(
262
+ device=scores.device, dtype=scores.dtype
263
+ )
264
+ if mask is not None:
265
+ scores = scores.masked_fill(mask == 0, -1e4)
266
+ if self.block_length is not None:
267
+ assert (
268
+ t_s == t_t
269
+ ), "Local attention is only available for self-attention."
270
+ block_mask = (
271
+ torch.ones_like(scores)
272
+ .triu(-self.block_length)
273
+ .tril(self.block_length)
274
+ )
275
+ scores = scores.masked_fill(block_mask == 0, -1e4)
276
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
277
+ p_attn = self.drop(p_attn)
278
+ output = torch.matmul(p_attn, value)
279
+ if self.window_size is not None:
280
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
281
+ value_relative_embeddings = self._get_relative_embeddings(
282
+ self.emb_rel_v, t_s
283
+ )
284
+ output = output + self._matmul_with_relative_values(
285
+ relative_weights, value_relative_embeddings
286
+ )
287
+ output = (
288
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
289
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
290
+ return output, p_attn
291
+
292
+ def _matmul_with_relative_values(self, x, y):
293
+ """
294
+ x: [b, h, l, m]
295
+ y: [h or 1, m, d]
296
+ ret: [b, h, l, d]
297
+ """
298
+ ret = torch.matmul(x, y.unsqueeze(0))
299
+ return ret
300
+
301
+ def _matmul_with_relative_keys(self, x, y):
302
+ """
303
+ x: [b, h, l, d]
304
+ y: [h or 1, m, d]
305
+ ret: [b, h, l, m]
306
+ """
307
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
308
+ return ret
309
+
310
+ def _get_relative_embeddings(self, relative_embeddings, length):
311
+ max_relative_position = 2 * self.window_size + 1
312
+ # Pad first before slice to avoid using cond ops.
313
+ pad_length = max(length - (self.window_size + 1), 0)
314
+ slice_start_position = max((self.window_size + 1) - length, 0)
315
+ slice_end_position = slice_start_position + 2 * length - 1
316
+ if pad_length > 0:
317
+ padded_relative_embeddings = F.pad(
318
+ relative_embeddings,
319
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
320
+ )
321
+ else:
322
+ padded_relative_embeddings = relative_embeddings
323
+ used_relative_embeddings = padded_relative_embeddings[
324
+ :, slice_start_position:slice_end_position
325
+ ]
326
+ return used_relative_embeddings
327
+
328
+ def _relative_position_to_absolute_position(self, x):
329
+ """
330
+ x: [b, h, l, 2*l-1]
331
+ ret: [b, h, l, l]
332
+ """
333
+ batch, heads, length, _ = x.size()
334
+ # Concat columns of pad to shift from relative to absolute indexing.
335
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
336
+
337
+ # Concat extra elements so that it adds up to shape (len+1, 2*len-1).
338
+ x_flat = x.view([batch, heads, length * 2 * length])
339
+ x_flat = F.pad(
340
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
341
+ )
342
+
343
+ # Reshape and slice out the padded elements.
344
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
345
+ :, :, :length, length - 1 :
346
+ ]
347
+ return x_final
348
+
349
+ def _absolute_position_to_relative_position(self, x):
350
+ """
351
+ x: [b, h, l, l]
352
+ ret: [b, h, l, 2*l-1]
353
+ """
354
+ batch, heads, length, _ = x.size()
355
+ # pad along the column dimension
356
+ x = F.pad(
357
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
358
+ )
359
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
360
+ # add 0's in the beginning that will skew the elements after reshape
361
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
362
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
363
+ return x_final
364
+
365
+ def _attention_bias_proximal(self, length):
366
+ """Bias for self-attention to encourage attention to close positions.
367
+ Args:
368
+ length: an integer scalar.
369
+ Returns:
370
+ a Tensor with shape [1, 1, length, length]
371
+ """
372
+ r = torch.arange(length, dtype=torch.float32)
373
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
374
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
375
+
376
+
377
+ class FFN(nn.Module):
378
+ def __init__(
379
+ self,
380
+ in_channels,
381
+ out_channels,
382
+ filter_channels,
383
+ kernel_size,
384
+ p_dropout=0.0,
385
+ activation=None,
386
+ causal=False,
387
+ ):
388
+ super().__init__()
389
+ self.in_channels = in_channels
390
+ self.out_channels = out_channels
391
+ self.filter_channels = filter_channels
392
+ self.kernel_size = kernel_size
393
+ self.p_dropout = p_dropout
394
+ self.activation = activation
395
+ self.causal = causal
396
+
397
+ if causal:
398
+ self.padding = self._causal_padding
399
+ else:
400
+ self.padding = self._same_padding
401
+
402
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
403
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
404
+ self.drop = nn.Dropout(p_dropout)
405
+
406
+ def forward(self, x, x_mask):
407
+ x = self.conv_1(self.padding(x * x_mask))
408
+ if self.activation == "gelu":
409
+ x = x * torch.sigmoid(1.702 * x)
410
+ else:
411
+ x = torch.relu(x)
412
+ x = self.drop(x)
413
+ x = self.conv_2(self.padding(x * x_mask))
414
+ return x * x_mask
415
+
416
+ def _causal_padding(self, x):
417
+ if self.kernel_size == 1:
418
+ return x
419
+ pad_l = self.kernel_size - 1
420
+ pad_r = 0
421
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
422
+ x = F.pad(x, commons.convert_pad_shape(padding))
423
+ return x
424
+
425
+ def _same_padding(self, x):
426
+ if self.kernel_size == 1:
427
+ return x
428
+ pad_l = (self.kernel_size - 1) // 2
429
+ pad_r = self.kernel_size // 2
430
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
431
+ x = F.pad(x, commons.convert_pad_shape(padding))
432
+ return x
433
+
434
+
435
+ import torch.nn as nn
436
+ from torch.nn.utils import remove_weight_norm, weight_norm
437
+
438
+
439
+ class Depthwise_Separable_Conv1D(nn.Module):
440
+ def __init__(
441
+ self,
442
+ in_channels,
443
+ out_channels,
444
+ kernel_size,
445
+ stride=1,
446
+ padding=0,
447
+ dilation=1,
448
+ bias=True,
449
+ padding_mode="zeros", # TODO: refine this type
450
+ device=None,
451
+ dtype=None,
452
+ ):
453
+ super().__init__()
454
+ self.depth_conv = nn.Conv1d(
455
+ in_channels=in_channels,
456
+ out_channels=in_channels,
457
+ kernel_size=kernel_size,
458
+ groups=in_channels,
459
+ stride=stride,
460
+ padding=padding,
461
+ dilation=dilation,
462
+ bias=bias,
463
+ padding_mode=padding_mode,
464
+ device=device,
465
+ dtype=dtype,
466
+ )
467
+ self.point_conv = nn.Conv1d(
468
+ in_channels=in_channels,
469
+ out_channels=out_channels,
470
+ kernel_size=1,
471
+ bias=bias,
472
+ device=device,
473
+ dtype=dtype,
474
+ )
475
+
476
+ def forward(self, input):
477
+ return self.point_conv(self.depth_conv(input))
478
+
479
+ def weight_norm(self):
480
+ self.depth_conv = weight_norm(self.depth_conv, name="weight")
481
+ self.point_conv = weight_norm(self.point_conv, name="weight")
482
+
483
+ def remove_weight_norm(self):
484
+ self.depth_conv = remove_weight_norm(self.depth_conv, name="weight")
485
+ self.point_conv = remove_weight_norm(self.point_conv, name="weight")
486
+
487
+
488
+ class Depthwise_Separable_TransposeConv1D(nn.Module):
489
+ def __init__(
490
+ self,
491
+ in_channels,
492
+ out_channels,
493
+ kernel_size,
494
+ stride=1,
495
+ padding=0,
496
+ output_padding=0,
497
+ bias=True,
498
+ dilation=1,
499
+ padding_mode="zeros", # TODO: refine this type
500
+ device=None,
501
+ dtype=None,
502
+ ):
503
+ super().__init__()
504
+ self.depth_conv = nn.ConvTranspose1d(
505
+ in_channels=in_channels,
506
+ out_channels=in_channels,
507
+ kernel_size=kernel_size,
508
+ groups=in_channels,
509
+ stride=stride,
510
+ output_padding=output_padding,
511
+ padding=padding,
512
+ dilation=dilation,
513
+ bias=bias,
514
+ padding_mode=padding_mode,
515
+ device=device,
516
+ dtype=dtype,
517
+ )
518
+ self.point_conv = nn.Conv1d(
519
+ in_channels=in_channels,
520
+ out_channels=out_channels,
521
+ kernel_size=1,
522
+ bias=bias,
523
+ device=device,
524
+ dtype=dtype,
525
+ )
526
+
527
+ def forward(self, input):
528
+ return self.point_conv(self.depth_conv(input))
529
+
530
+ def weight_norm(self):
531
+ self.depth_conv = weight_norm(self.depth_conv, name="weight")
532
+ self.point_conv = weight_norm(self.point_conv, name="weight")
533
+
534
+ def remove_weight_norm(self):
535
+ remove_weight_norm(self.depth_conv, name="weight")
536
+ remove_weight_norm(self.point_conv, name="weight")
537
+
538
+
539
+ def weight_norm_modules(module, name="weight", dim=0):
540
+ if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
541
+ module, Depthwise_Separable_TransposeConv1D
542
+ ):
543
+ module.weight_norm()
544
+ return module
545
+ else:
546
+ return weight_norm(module, name, dim)
547
+
548
+
549
+ def remove_weight_norm_modules(module, name="weight"):
550
+ if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
551
+ module, Depthwise_Separable_TransposeConv1D
552
+ ):
553
+ module.remove_weight_norm()
554
+ else:
555
+ remove_weight_norm(module, name)
556
+
557
+
558
+ class FFT(nn.Module):
559
+ def __init__(
560
+ self,
561
+ hidden_channels,
562
+ filter_channels,
563
+ n_heads,
564
+ n_layers=1,
565
+ kernel_size=1,
566
+ p_dropout=0.0,
567
+ proximal_bias=False,
568
+ proximal_init=True,
569
+ isflow=False,
570
+ **kwargs
571
+ ):
572
+ super().__init__()
573
+ self.hidden_channels = hidden_channels
574
+ self.filter_channels = filter_channels
575
+ self.n_heads = n_heads
576
+ self.n_layers = n_layers
577
+ self.kernel_size = kernel_size
578
+ self.p_dropout = p_dropout
579
+ self.proximal_bias = proximal_bias
580
+ self.proximal_init = proximal_init
581
+ if isflow:
582
+ cond_layer = torch.nn.Conv1d(
583
+ kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
584
+ )
585
+ self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
586
+ self.cond_layer = weight_norm_modules(cond_layer, name="weight")
587
+ self.gin_channels = kwargs["gin_channels"]
588
+ self.drop = nn.Dropout(p_dropout)
589
+ self.self_attn_layers = nn.ModuleList()
590
+ self.norm_layers_0 = nn.ModuleList()
591
+ self.ffn_layers = nn.ModuleList()
592
+ self.norm_layers_1 = nn.ModuleList()
593
+ for i in range(self.n_layers):
594
+ self.self_attn_layers.append(
595
+ MultiHeadAttention(
596
+ hidden_channels,
597
+ hidden_channels,
598
+ n_heads,
599
+ p_dropout=p_dropout,
600
+ proximal_bias=proximal_bias,
601
+ proximal_init=proximal_init,
602
+ )
603
+ )
604
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
605
+ self.ffn_layers.append(
606
+ FFN(
607
+ hidden_channels,
608
+ hidden_channels,
609
+ filter_channels,
610
+ kernel_size,
611
+ p_dropout=p_dropout,
612
+ causal=True,
613
+ )
614
+ )
615
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
616
+
617
+ def forward(self, x, x_mask, g=None):
618
+ """
619
+ x: decoder input
620
+ g: optional conditioning tensor (used when isflow is True)
621
+ """
622
+ if g is not None:
623
+ g = self.cond_layer(g)
624
+
625
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
626
+ device=x.device, dtype=x.dtype
627
+ )
628
+ x = x * x_mask
629
+ for i in range(self.n_layers):
630
+ if g is not None:
631
+ x = self.cond_pre(x)
632
+ cond_offset = i * 2 * self.hidden_channels
633
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
634
+ x = commons.fused_add_tanh_sigmoid_multiply(
635
+ x, g_l, torch.IntTensor([self.hidden_channels])
636
+ )
637
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
638
+ y = self.drop(y)
639
+ x = self.norm_layers_0[i](x + y)
640
+
641
+ y = self.ffn_layers[i](x, x_mask)
642
+ y = self.drop(y)
643
+ x = self.norm_layers_1[i](x + y)
644
+ x = x * x_mask
645
+ return x
646
+
647
+
648
+ class TransformerCouplingLayer(nn.Module):
649
+ def __init__(
650
+ self,
651
+ channels,
652
+ hidden_channels,
653
+ kernel_size,
654
+ n_layers,
655
+ n_heads,
656
+ p_dropout=0,
657
+ filter_channels=0,
658
+ mean_only=False,
659
+ wn_sharing_parameter=None,
660
+ gin_channels=0,
661
+ ):
662
+ assert channels % 2 == 0, "channels should be divisible by 2"
663
+ super().__init__()
664
+ self.channels = channels
665
+ self.hidden_channels = hidden_channels
666
+ self.kernel_size = kernel_size
667
+ self.n_layers = n_layers
668
+ self.half_channels = channels // 2
669
+ self.mean_only = mean_only
670
+
671
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
672
+ self.enc = (
673
+ Encoder(
674
+ hidden_channels,
675
+ filter_channels,
676
+ n_heads,
677
+ n_layers,
678
+ kernel_size,
679
+ p_dropout,
680
+ isflow=True,
681
+ gin_channels=gin_channels,
682
+ )
683
+ if wn_sharing_parameter is None
684
+ else wn_sharing_parameter
685
+ )
686
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
687
+ self.post.weight.data.zero_()
688
+ self.post.bias.data.zero_()
689
+
690
+ def forward(self, x, x_mask, g=None, reverse=False):
691
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
692
+ h = self.pre(x0) * x_mask
693
+ h = self.enc(h, x_mask, g=g)
694
+ stats = self.post(h) * x_mask
695
+ if not self.mean_only:
696
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
697
+ else:
698
+ m = stats
699
+ logs = torch.zeros_like(m)
700
+
701
+ if not reverse:
702
+ x1 = m + x1 * torch.exp(logs) * x_mask
703
+ x = torch.cat([x0, x1], 1)
704
+ logdet = torch.sum(logs, [1, 2])
705
+ return x, logdet
706
+ else:
707
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
708
+ x = torch.cat([x0, x1], 1)
709
+ return x
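
A minimal sketch (not part of the commit) of the affine coupling arithmetic that TransformerCouplingLayer implements above: the transformer encoder predicts (m, logs) from one half of the channels, the other half is transformed affinely, and the reverse pass inverts it exactly. Random tensors stand in for the encoder output here.

import torch

def coupling(x, m, logs, reverse=False):
    # x: [b, c, t]; m, logs: [b, c // 2, t], playing the role of self.post(self.enc(...))
    x0, x1 = torch.split(x, x.size(1) // 2, dim=1)
    if not reverse:
        x1 = m + x1 * torch.exp(logs)           # forward: y1 = m + x1 * exp(logs)
        logdet = torch.sum(logs, dim=[1, 2])    # log|det J| of the affine map
        return torch.cat([x0, x1], dim=1), logdet
    x1 = (x1 - m) * torch.exp(-logs)            # reverse: x1 = (y1 - m) * exp(-logs)
    return torch.cat([x0, x1], dim=1)

x = torch.randn(2, 8, 5)
m, logs = torch.randn(2, 4, 5), 0.1 * torch.randn(2, 4, 5)
y, _ = coupling(x, m, logs)
assert torch.allclose(coupling(y, m, logs, reverse=True), x, atol=1e-6)

With mean_only=True the layer predicts only m and logs stays zero, so that half of the flow contributes nothing to the log-determinant.
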
module/attentions_onnx.py ADDED
@@ -0,0 +1,354 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
+ import logging
5
+
6
+ from module import commons
7
+ from module.modules import LayerNorm
8
+
9
+
10
+ class LayerNorm(nn.Module):
11
+ def __init__(self, channels, eps=1e-5):
12
+ super().__init__()
13
+ self.channels = channels
14
+ self.eps = eps
15
+
16
+ self.gamma = nn.Parameter(torch.ones(channels))
17
+ self.beta = nn.Parameter(torch.zeros(channels))
18
+
19
+ def forward(self, x):
20
+ x = x.transpose(1, -1)
21
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
22
+ return x.transpose(1, -1)
23
+
24
+
25
+ @torch.jit.script
26
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
27
+ n_channels_int = n_channels[0]
28
+ in_act = input_a + input_b
29
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
30
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
31
+ acts = t_act * s_act
32
+ return acts
33
+
34
+
35
+ class Encoder(nn.Module):
36
+ def __init__(
37
+ self,
38
+ hidden_channels,
39
+ filter_channels,
40
+ n_heads,
41
+ n_layers,
42
+ kernel_size=1,
43
+ p_dropout=0.0,
44
+ window_size=4,
45
+ isflow=True,
46
+ **kwargs
47
+ ):
48
+ super().__init__()
49
+ self.hidden_channels = hidden_channels
50
+ self.filter_channels = filter_channels
51
+ self.n_heads = n_heads
52
+ self.n_layers = n_layers
53
+ self.kernel_size = kernel_size
54
+ self.p_dropout = p_dropout
55
+ self.window_size = window_size
56
+ # if isflow:
57
+ # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
58
+ # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
59
+ # self.cond_layer = weight_norm(cond_layer, name='weight')
60
+ # self.gin_channels = 256
61
+ self.cond_layer_idx = self.n_layers
62
+ if "gin_channels" in kwargs:
63
+ self.gin_channels = kwargs["gin_channels"]
64
+ if self.gin_channels != 0:
65
+ self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
66
+ # vits2 says 3rd block, so idx is 2 by default
67
+ self.cond_layer_idx = (
68
+ kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
69
+ )
70
+ logging.debug(self.gin_channels, self.cond_layer_idx)
71
+ assert (
72
+ self.cond_layer_idx < self.n_layers
73
+ ), "cond_layer_idx should be less than n_layers"
74
+ self.drop = nn.Dropout(p_dropout)
75
+ self.attn_layers = nn.ModuleList()
76
+ self.norm_layers_1 = nn.ModuleList()
77
+ self.ffn_layers = nn.ModuleList()
78
+ self.norm_layers_2 = nn.ModuleList()
79
+ for i in range(self.n_layers):
80
+ self.attn_layers.append(
81
+ MultiHeadAttention(
82
+ hidden_channels,
83
+ hidden_channels,
84
+ n_heads,
85
+ p_dropout=p_dropout,
86
+ window_size=window_size,
87
+ )
88
+ )
89
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
90
+ self.ffn_layers.append(
91
+ FFN(
92
+ hidden_channels,
93
+ hidden_channels,
94
+ filter_channels,
95
+ kernel_size,
96
+ p_dropout=p_dropout,
97
+ )
98
+ )
99
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
100
+
101
+ def forward(self, x, x_mask, g=None):
102
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
103
+ x = x * x_mask
104
+ for i in range(self.n_layers):
105
+ if i == self.cond_layer_idx and g is not None:
106
+ g = self.spk_emb_linear(g.transpose(1, 2))
107
+ g = g.transpose(1, 2)
108
+ x = x + g
109
+ x = x * x_mask
110
+ y = self.attn_layers[i](x, x, attn_mask)
111
+ y = self.drop(y)
112
+ x = self.norm_layers_1[i](x + y)
113
+
114
+ y = self.ffn_layers[i](x, x_mask)
115
+ y = self.drop(y)
116
+ x = self.norm_layers_2[i](x + y)
117
+ x = x * x_mask
118
+ return x
119
+
120
+
121
+ class MultiHeadAttention(nn.Module):
122
+ def __init__(
123
+ self,
124
+ channels,
125
+ out_channels,
126
+ n_heads,
127
+ p_dropout=0.0,
128
+ window_size=None,
129
+ heads_share=True,
130
+ block_length=None,
131
+ proximal_bias=False,
132
+ proximal_init=False,
133
+ ):
134
+ super().__init__()
135
+ assert channels % n_heads == 0
136
+
137
+ self.channels = channels
138
+ self.out_channels = out_channels
139
+ self.n_heads = n_heads
140
+ self.p_dropout = p_dropout
141
+ self.window_size = window_size
142
+ self.heads_share = heads_share
143
+ self.block_length = block_length
144
+ self.proximal_bias = proximal_bias
145
+ self.proximal_init = proximal_init
146
+ self.attn = None
147
+
148
+ self.k_channels = channels // n_heads
149
+ self.conv_q = nn.Conv1d(channels, channels, 1)
150
+ self.conv_k = nn.Conv1d(channels, channels, 1)
151
+ self.conv_v = nn.Conv1d(channels, channels, 1)
152
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
153
+ self.drop = nn.Dropout(p_dropout)
154
+
155
+ if window_size is not None:
156
+ n_heads_rel = 1 if heads_share else n_heads
157
+ rel_stddev = self.k_channels**-0.5
158
+ self.emb_rel_k = nn.Parameter(
159
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
160
+ * rel_stddev
161
+ )
162
+ self.emb_rel_v = nn.Parameter(
163
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
164
+ * rel_stddev
165
+ )
166
+
167
+ nn.init.xavier_uniform_(self.conv_q.weight)
168
+ nn.init.xavier_uniform_(self.conv_k.weight)
169
+ nn.init.xavier_uniform_(self.conv_v.weight)
170
+ if proximal_init:
171
+ with torch.no_grad():
172
+ self.conv_k.weight.copy_(self.conv_q.weight)
173
+ self.conv_k.bias.copy_(self.conv_q.bias)
174
+
175
+ def forward(self, x, c, attn_mask=None):
176
+ q = self.conv_q(x)
177
+ k = self.conv_k(c)
178
+ v = self.conv_v(c)
179
+
180
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
181
+
182
+ x = self.conv_o(x)
183
+ return x
184
+
185
+ def attention(self, query, key, value, mask=None):
186
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
187
+ b, d, t_s, _ = (*key.size(), query.size(2))
188
+ query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
189
+ key = key.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
190
+ value = value.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
191
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
192
+
193
+ if self.window_size is not None:
194
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
195
+ rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
196
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
197
+ scores = scores + scores_local
198
+
199
+ if mask is not None:
200
+ scores = scores.masked_fill(mask == 0, -1e4)
201
+
202
+ p_attn = F.softmax(scores, dim=-1)
203
+ p_attn = self.drop(p_attn)
204
+ output = torch.matmul(p_attn, value)
205
+
206
+ if self.window_size is not None:
207
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
208
+ value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
209
+ output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
210
+
211
+ output = (output.transpose(2, 3).contiguous().view(b, d, -1))
212
+ return output, p_attn
213
+
214
+ def _matmul_with_relative_values(self, x, y):
215
+ """
216
+ x: [b, h, l, m]
217
+ y: [h or 1, m, d]
218
+ ret: [b, h, l, d]
219
+ """
220
+ ret = torch.matmul(x, y.unsqueeze(0))
221
+ return ret
222
+
223
+ def _matmul_with_relative_keys(self, x, y):
224
+ """
225
+ x: [b, h, l, d]
226
+ y: [h or 1, m, d]
227
+ ret: [b, h, l, m]
228
+ """
229
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
230
+ return ret
231
+
232
+ def _get_relative_embeddings(self, relative_embeddings, length):
233
+ max_relative_position = 2 * self.window_size + 1
234
+ # Pad first before slice to avoid using cond ops.
235
+ pad_l = torch.zeros((1), dtype = torch.int64) + length - (self.window_size + 1)
236
+ pad_s = torch.zeros((1), dtype = torch.int64) + (self.window_size + 1) - length
237
+ pad_length = torch.max(pad_l, other=torch.zeros((1), dtype = torch.int64))
238
+ slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype = torch.int64))
239
+
240
+ slice_end_position = slice_start_position + 2 * length - 1
241
+ padded_relative_embeddings = F.pad(
242
+ relative_embeddings,
243
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
244
+ )
245
+ used_relative_embeddings = padded_relative_embeddings[
246
+ :, slice_start_position:slice_end_position
247
+ ]
248
+ return used_relative_embeddings
249
+
250
+ def _relative_position_to_absolute_position(self, x):
251
+ """
252
+ x: [b, h, l, 2*l-1]
253
+ ret: [b, h, l, l]
254
+ """
255
+ batch, heads, length, _ = x.size()
256
+ # Concat columns of pad to shift from relative to absolute indexing.
257
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
258
+
259
+ # Concat extra elements so that it adds up to shape (len+1, 2*len-1).
260
+ x_flat = x.view([batch, heads, length * 2 * length])
261
+ x_flat = F.pad(
262
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
263
+ )
264
+
265
+ # Reshape and slice out the padded elements.
266
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
267
+ :, :, :length, length - 1 :
268
+ ]
269
+ return x_final
270
+
271
+ def _absolute_position_to_relative_position(self, x):
272
+ """
273
+ x: [b, h, l, l]
274
+ ret: [b, h, l, 2*l-1]
275
+ """
276
+ batch, heads, length, _ = x.size()
277
+ # padd along column
278
+ x = F.pad(
279
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
280
+ )
281
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
282
+ # add 0's in the beginning that will skew the elements after reshape
283
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
284
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
285
+ return x_final
286
+
287
+ def _attention_bias_proximal(self, length):
288
+ """Bias for self-attention to encourage attention to close positions.
289
+ Args:
290
+ length: an integer scalar.
291
+ Returns:
292
+ a Tensor with shape [1, 1, length, length]
293
+ """
294
+ r = torch.arange(length, dtype=torch.float32)
295
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
296
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
297
+
298
+
299
+ class FFN(nn.Module):
300
+ def __init__(
301
+ self,
302
+ in_channels,
303
+ out_channels,
304
+ filter_channels,
305
+ kernel_size,
306
+ p_dropout=0.0,
307
+ activation=None,
308
+ causal=False,
309
+ ):
310
+ super().__init__()
311
+ self.in_channels = in_channels
312
+ self.out_channels = out_channels
313
+ self.filter_channels = filter_channels
314
+ self.kernel_size = kernel_size
315
+ self.p_dropout = p_dropout
316
+ self.activation = activation
317
+ self.causal = causal
318
+
319
+ if causal:
320
+ self.padding = self._causal_padding
321
+ else:
322
+ self.padding = self._same_padding
323
+
324
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
325
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
326
+ self.drop = nn.Dropout(p_dropout)
327
+
328
+ def forward(self, x, x_mask):
329
+ x = self.conv_1(self.padding(x * x_mask))
330
+ if self.activation == "gelu":
331
+ x = x * torch.sigmoid(1.702 * x)
332
+ else:
333
+ x = torch.relu(x)
334
+ x = self.drop(x)
335
+ x = self.conv_2(self.padding(x * x_mask))
336
+ return x * x_mask
337
+
338
+ def _causal_padding(self, x):
339
+ if self.kernel_size == 1:
340
+ return x
341
+ pad_l = self.kernel_size - 1
342
+ pad_r = 0
343
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
344
+ x = F.pad(x, commons.convert_pad_shape(padding))
345
+ return x
346
+
347
+ def _same_padding(self, x):
348
+ if self.kernel_size == 1:
349
+ return x
350
+ pad_l = (self.kernel_size - 1) // 2
351
+ pad_r = self.kernel_size // 2
352
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
353
+ x = F.pad(x, commons.convert_pad_shape(padding))
354
+ return x
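
The pad-and-reshape trick behind _relative_position_to_absolute_position (shared by both attention modules) is easiest to see on a tiny tensor; the ONNX-oriented variant above computes pad amounts with tensor ops instead of Python max() so the same logic stays traceable for export. A self-contained sketch with l = 3:

import torch
import torch.nn.functional as F

def rel_to_abs(x):
    # x: [b, h, l, 2*l - 1] relative logits -> [b, h, l, l] absolute logits,
    # using the same pad/flatten/reshape steps as the method above.
    b, h, l, _ = x.size()
    x = F.pad(x, [0, 1])                       # one extra column on the right
    x_flat = x.reshape(b, h, l * 2 * l)
    x_flat = F.pad(x_flat, [0, l - 1])         # pad so the next reshape lines up
    return x_flat.reshape(b, h, l + 1, 2 * l - 1)[:, :, :l, l - 1:]

x = torch.arange(5).float().repeat(1, 1, 3, 1)   # every row holds the 5 relative offsets as [0, 1, 2, 3, 4]
print(rel_to_abs(x))
# row 0 -> [2, 3, 4], row 1 -> [1, 2, 3], row 2 -> [0, 1, 2]:
# each query row reads the relative embedding at offset (key - query).
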
module/commons.py ADDED
@@ -0,0 +1,189 @@
1
+ import math
2
+ import torch
3
+ from torch.nn import functional as F
4
+
5
+
6
+ def init_weights(m, mean=0.0, std=0.01):
7
+ classname = m.__class__.__name__
8
+ if classname.find("Conv") != -1:
9
+ m.weight.data.normal_(mean, std)
10
+
11
+
12
+ def get_padding(kernel_size, dilation=1):
13
+ return int((kernel_size * dilation - dilation) / 2)
14
+
15
+
16
+ def convert_pad_shape(pad_shape):
17
+ l = pad_shape[::-1]
18
+ pad_shape = [item for sublist in l for item in sublist]
19
+ return pad_shape
20
+
21
+
22
+ def intersperse(lst, item):
23
+ result = [item] * (len(lst) * 2 + 1)
24
+ result[1::2] = lst
25
+ return result
26
+
27
+
28
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
29
+ """KL(P||Q)"""
30
+ kl = (logs_q - logs_p) - 0.5
31
+ kl += (
32
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
33
+ )
34
+ return kl
35
+
36
+
37
+ def rand_gumbel(shape):
38
+ """Sample from the Gumbel distribution, protect from overflows."""
39
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
40
+ return -torch.log(-torch.log(uniform_samples))
41
+
42
+
43
+ def rand_gumbel_like(x):
44
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
45
+ return g
46
+
47
+
48
+ def slice_segments(x, ids_str, segment_size=4):
49
+ ret = torch.zeros_like(x[:, :, :segment_size])
50
+ for i in range(x.size(0)):
51
+ idx_str = ids_str[i]
52
+ idx_end = idx_str + segment_size
53
+ ret[i] = x[i, :, idx_str:idx_end]
54
+ return ret
55
+
56
+
57
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
58
+ b, d, t = x.size()
59
+ if x_lengths is None:
60
+ x_lengths = t
61
+ ids_str_max = x_lengths - segment_size + 1
62
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
63
+ ret = slice_segments(x, ids_str, segment_size)
64
+ return ret, ids_str
65
+
66
+
67
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
68
+ position = torch.arange(length, dtype=torch.float)
69
+ num_timescales = channels // 2
70
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
71
+ num_timescales - 1
72
+ )
73
+ inv_timescales = min_timescale * torch.exp(
74
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
75
+ )
76
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
77
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
78
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
79
+ signal = signal.view(1, channels, length)
80
+ return signal
81
+
82
+
83
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
84
+ b, channels, length = x.size()
85
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
86
+ return x + signal.to(dtype=x.dtype, device=x.device)
87
+
88
+
89
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
90
+ b, channels, length = x.size()
91
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
92
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
93
+
94
+
95
+ def subsequent_mask(length):
96
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
97
+ return mask
98
+
99
+
100
+ @torch.jit.script
101
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
102
+ n_channels_int = n_channels[0]
103
+ in_act = input_a + input_b
104
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
105
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
106
+ acts = t_act * s_act
107
+ return acts
108
+
109
+
110
+ def convert_pad_shape(pad_shape):
111
+ l = pad_shape[::-1]
112
+ pad_shape = [item for sublist in l for item in sublist]
113
+ return pad_shape
114
+
115
+
116
+ def shift_1d(x):
117
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
118
+ return x
119
+
120
+
121
+ def sequence_mask(length, max_length=None):
122
+ if max_length is None:
123
+ max_length = length.max()
124
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
125
+ return x.unsqueeze(0) < length.unsqueeze(1)
126
+
127
+
128
+ def generate_path(duration, mask):
129
+ """
130
+ duration: [b, 1, t_x]
131
+ mask: [b, 1, t_y, t_x]
132
+ """
133
+ device = duration.device
134
+
135
+ b, _, t_y, t_x = mask.shape
136
+ cum_duration = torch.cumsum(duration, -1)
137
+
138
+ cum_duration_flat = cum_duration.view(b * t_x)
139
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
140
+ path = path.view(b, t_x, t_y)
141
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
142
+ path = path.unsqueeze(1).transpose(2, 3) * mask
143
+ return path
144
+
145
+
146
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
147
+ if isinstance(parameters, torch.Tensor):
148
+ parameters = [parameters]
149
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
150
+ norm_type = float(norm_type)
151
+ if clip_value is not None:
152
+ clip_value = float(clip_value)
153
+
154
+ total_norm = 0
155
+ for p in parameters:
156
+ param_norm = p.grad.data.norm(norm_type)
157
+ total_norm += param_norm.item() ** norm_type
158
+ if clip_value is not None:
159
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
160
+ total_norm = total_norm ** (1.0 / norm_type)
161
+ return total_norm
162
+
163
+
164
+ def squeeze(x, x_mask=None, n_sqz=2):
165
+ b, c, t = x.size()
166
+
167
+ t = (t // n_sqz) * n_sqz
168
+ x = x[:, :, :t]
169
+ x_sqz = x.view(b, c, t // n_sqz, n_sqz)
170
+ x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)
171
+
172
+ if x_mask is not None:
173
+ x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz]
174
+ else:
175
+ x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
176
+ return x_sqz * x_mask, x_mask
177
+
178
+
179
+ def unsqueeze(x, x_mask=None, n_sqz=2):
180
+ b, c, t = x.size()
181
+
182
+ x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
183
+ x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)
184
+
185
+ if x_mask is not None:
186
+ x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
187
+ else:
188
+ x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
189
+ return x_unsqz * x_mask, x_mask
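
A short usage sketch for two of the helpers above (shapes chosen for illustration; assumes the repository root is on the import path):

import torch
from module.commons import sequence_mask, rand_slice_segments

lengths = torch.tensor([3, 5])
print(sequence_mask(lengths))        # [[True, True, True, False, False], [True, True, True, True, True]]
x = torch.randn(2, 4, 10)            # [batch, channels, time]
segments, ids_str = rand_slice_segments(x, torch.tensor([10, 8]), segment_size=4)
print(segments.shape)                # torch.Size([2, 4, 4]); ids_str holds the random start frames
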
module/core_vq.py ADDED
@@ -0,0 +1,383 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # This implementation is inspired from
8
+ # https://github.com/lucidrains/vector-quantize-pytorch
9
+ # which is released under MIT License. Hereafter, the original license:
10
+ # MIT License
11
+ #
12
+ # Copyright (c) 2020 Phil Wang
13
+ #
14
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ # of this software and associated documentation files (the "Software"), to deal
16
+ # in the Software without restriction, including without limitation the rights
17
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ # copies of the Software, and to permit persons to whom the Software is
19
+ # furnished to do so, subject to the following conditions:
20
+ #
21
+ # The above copyright notice and this permission notice shall be included in all
22
+ # copies or substantial portions of the Software.
23
+ #
24
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30
+ # SOFTWARE.
31
+
32
+ """Core vector quantization implementation."""
33
+ import typing as tp
34
+
35
+ from einops import rearrange, repeat
36
+ import torch
37
+ from torch import nn
38
+ import torch.nn.functional as F
39
+ from tqdm import tqdm
40
+
41
+
42
+ def default(val: tp.Any, d: tp.Any) -> tp.Any:
43
+ return val if val is not None else d
44
+
45
+
46
+ def ema_inplace(moving_avg, new, decay: float):
47
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
48
+
49
+
50
+ def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
51
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
52
+
53
+
54
+ def uniform_init(*shape: int):
55
+ t = torch.empty(shape)
56
+ nn.init.kaiming_uniform_(t)
57
+ return t
58
+
59
+
60
+ def sample_vectors(samples, num: int):
61
+ num_samples, device = samples.shape[0], samples.device
62
+
63
+ if num_samples >= num:
64
+ indices = torch.randperm(num_samples, device=device)[:num]
65
+ else:
66
+ indices = torch.randint(0, num_samples, (num,), device=device)
67
+
68
+ return samples[indices]
69
+
70
+
71
+ def kmeans(samples, num_clusters: int, num_iters: int = 10):
72
+ dim, dtype = samples.shape[-1], samples.dtype
73
+ max_kmeans_samples = 500
74
+ samples = samples[:max_kmeans_samples, :]
75
+ means = sample_vectors(samples, num_clusters)
76
+
77
+ print("kmeans start ... ")
78
+ for _ in tqdm(range(num_iters)):
79
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
80
+ dists = -(diffs**2).sum(dim=-1)
81
+
82
+ buckets = dists.max(dim=-1).indices
83
+ bins = torch.bincount(buckets, minlength=num_clusters)
84
+ zero_mask = bins == 0
85
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
86
+
87
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
88
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
89
+ new_means = new_means / bins_min_clamped[..., None]
90
+
91
+ means = torch.where(zero_mask[..., None], means, new_means)
92
+
93
+ return means, bins
94
+
95
+
96
+ class EuclideanCodebook(nn.Module):
97
+ """Codebook with Euclidean distance.
98
+ Args:
99
+ dim (int): Dimension.
100
+ codebook_size (int): Codebook size.
101
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
102
+ If set to true, run the k-means algorithm on the first training batch and use
103
+ the learned centroids as initialization.
104
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
105
+ decay (float): Decay for exponential moving average over the codebooks.
106
+ epsilon (float): Epsilon value for numerical stability.
107
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
108
+ that have an exponential moving average cluster size less than the specified threshold with
109
+ randomly selected vector from the current batch.
110
+ """
111
+
112
+ def __init__(
113
+ self,
114
+ dim: int,
115
+ codebook_size: int,
116
+ kmeans_init: int = False,
117
+ kmeans_iters: int = 10,
118
+ decay: float = 0.99,
119
+ epsilon: float = 1e-5,
120
+ threshold_ema_dead_code: int = 2,
121
+ ):
122
+ super().__init__()
123
+ self.decay = decay
124
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
125
+ uniform_init if not kmeans_init else torch.zeros
126
+ )
127
+ embed = init_fn(codebook_size, dim)
128
+
129
+ self.codebook_size = codebook_size
130
+
131
+ self.kmeans_iters = kmeans_iters
132
+ self.epsilon = epsilon
133
+ self.threshold_ema_dead_code = threshold_ema_dead_code
134
+
135
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
136
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
137
+ self.register_buffer("embed", embed)
138
+ self.register_buffer("embed_avg", embed.clone())
139
+
140
+ @torch.jit.ignore
141
+ def init_embed_(self, data):
142
+ if self.inited:
143
+ return
144
+
145
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
146
+ self.embed.data.copy_(embed)
147
+ self.embed_avg.data.copy_(embed.clone())
148
+ self.cluster_size.data.copy_(cluster_size)
149
+ self.inited.data.copy_(torch.Tensor([True]))
150
+ # Make sure all buffers across workers are in sync after initialization
151
+ # broadcast_tensors(self.buffers())
152
+
153
+ def replace_(self, samples, mask):
154
+ modified_codebook = torch.where(
155
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
156
+ )
157
+ self.embed.data.copy_(modified_codebook)
158
+
159
+ def expire_codes_(self, batch_samples):
160
+ if self.threshold_ema_dead_code == 0:
161
+ return
162
+
163
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
164
+ if not torch.any(expired_codes):
165
+ return
166
+
167
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
168
+ self.replace_(batch_samples, mask=expired_codes)
169
+ # broadcast_tensors(self.buffers())
170
+
171
+ def preprocess(self, x):
172
+ x = rearrange(x, "... d -> (...) d")
173
+ return x
174
+
175
+ def quantize(self, x):
176
+ embed = self.embed.t()
177
+ dist = -(
178
+ x.pow(2).sum(1, keepdim=True)
179
+ - 2 * x @ embed
180
+ + embed.pow(2).sum(0, keepdim=True)
181
+ )
182
+ embed_ind = dist.max(dim=-1).indices
183
+ return embed_ind
184
+
185
+ def postprocess_emb(self, embed_ind, shape):
186
+ return embed_ind.view(*shape[:-1])
187
+
188
+ def dequantize(self, embed_ind):
189
+ quantize = F.embedding(embed_ind, self.embed)
190
+ return quantize
191
+
192
+ def encode(self, x):
193
+ shape = x.shape
194
+ # pre-process
195
+ x = self.preprocess(x)
196
+ # quantize
197
+ embed_ind = self.quantize(x)
198
+ # post-process
199
+ embed_ind = self.postprocess_emb(embed_ind, shape)
200
+ return embed_ind
201
+
202
+ def decode(self, embed_ind):
203
+ quantize = self.dequantize(embed_ind)
204
+ return quantize
205
+
206
+ def forward(self, x):
207
+ shape, dtype = x.shape, x.dtype
208
+ x = self.preprocess(x)
209
+
210
+ self.init_embed_(x)
211
+
212
+ embed_ind = self.quantize(x)
213
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
214
+ embed_ind = self.postprocess_emb(embed_ind, shape)
215
+ quantize = self.dequantize(embed_ind)
216
+
217
+ if self.training:
218
+ # We do the expiry of code at that point as buffers are in sync
219
+ # and all the workers will take the same decision.
220
+ self.expire_codes_(x)
221
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
222
+ embed_sum = x.t() @ embed_onehot
223
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
224
+ cluster_size = (
225
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
226
+ * self.cluster_size.sum()
227
+ )
228
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
229
+ self.embed.data.copy_(embed_normalized)
230
+
231
+ return quantize, embed_ind
232
+
233
+
234
+ class VectorQuantization(nn.Module):
235
+ """Vector quantization implementation.
236
+ Currently supports only euclidean distance.
237
+ Args:
238
+ dim (int): Dimension
239
+ codebook_size (int): Codebook size
240
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
241
+ decay (float): Decay for exponential moving average over the codebooks.
242
+ epsilon (float): Epsilon value for numerical stability.
243
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
244
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
245
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
246
+ that have an exponential moving average cluster size less than the specified threshold with
247
+ randomly selected vector from the current batch.
248
+ commitment_weight (float): Weight for commitment loss.
249
+ """
250
+
251
+ def __init__(
252
+ self,
253
+ dim: int,
254
+ codebook_size: int,
255
+ codebook_dim: tp.Optional[int] = None,
256
+ decay: float = 0.99,
257
+ epsilon: float = 1e-5,
258
+ kmeans_init: bool = True,
259
+ kmeans_iters: int = 50,
260
+ threshold_ema_dead_code: int = 2,
261
+ commitment_weight: float = 1.0,
262
+ ):
263
+ super().__init__()
264
+ _codebook_dim: int = default(codebook_dim, dim)
265
+
266
+ requires_projection = _codebook_dim != dim
267
+ self.project_in = (
268
+ nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
269
+ )
270
+ self.project_out = (
271
+ nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
272
+ )
273
+
274
+ self.epsilon = epsilon
275
+ self.commitment_weight = commitment_weight
276
+
277
+ self._codebook = EuclideanCodebook(
278
+ dim=_codebook_dim,
279
+ codebook_size=codebook_size,
280
+ kmeans_init=kmeans_init,
281
+ kmeans_iters=kmeans_iters,
282
+ decay=decay,
283
+ epsilon=epsilon,
284
+ threshold_ema_dead_code=threshold_ema_dead_code,
285
+ )
286
+ self.codebook_size = codebook_size
287
+
288
+ @property
289
+ def codebook(self):
290
+ return self._codebook.embed
291
+
292
+ def encode(self, x):
293
+ x = rearrange(x, "b d n -> b n d")
294
+ x = self.project_in(x)
295
+ embed_in = self._codebook.encode(x)
296
+ return embed_in
297
+
298
+ def decode(self, embed_ind):
299
+ quantize = self._codebook.decode(embed_ind)
300
+ quantize = self.project_out(quantize)
301
+ quantize = rearrange(quantize, "b n d -> b d n")
302
+ return quantize
303
+
304
+ def forward(self, x):
305
+ device = x.device
306
+ x = rearrange(x, "b d n -> b n d")
307
+ x = self.project_in(x)
308
+
309
+ quantize, embed_ind = self._codebook(x)
310
+
311
+ if self.training:
312
+ quantize = x + (quantize - x).detach()
313
+
314
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
315
+
316
+ if self.training:
317
+ if self.commitment_weight > 0:
318
+ commit_loss = F.mse_loss(quantize.detach(), x)
319
+ loss = loss + commit_loss * self.commitment_weight
320
+
321
+ quantize = self.project_out(quantize)
322
+ quantize = rearrange(quantize, "b n d -> b d n")
323
+ return quantize, embed_ind, loss
324
+
325
+
326
+ class ResidualVectorQuantization(nn.Module):
327
+ """Residual vector quantization implementation.
328
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
329
+ """
330
+
331
+ def __init__(self, *, num_quantizers, **kwargs):
332
+ super().__init__()
333
+ self.layers = nn.ModuleList(
334
+ [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
335
+ )
336
+
337
+ def forward(
338
+ self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
339
+ ):
340
+ quantized_out = 0.0
341
+ residual = x
342
+
343
+ all_losses = []
344
+ all_indices = []
345
+ out_quantized = []
346
+
347
+ n_q = n_q or len(self.layers)
348
+
349
+ for i, layer in enumerate(self.layers[:n_q]):
350
+ quantized, indices, loss = layer(residual)
351
+ residual = residual - quantized
352
+ quantized_out = quantized_out + quantized
353
+
354
+ all_indices.append(indices)
355
+ all_losses.append(loss)
356
+ if layers and i in layers:
357
+ out_quantized.append(quantized)
358
+
359
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
360
+ return quantized_out, out_indices, out_losses, out_quantized
361
+
362
+ def encode(
363
+ self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
364
+ ) -> torch.Tensor:
365
+ residual = x
366
+ all_indices = []
367
+ n_q = n_q or len(self.layers)
368
+ st = st or 0
369
+ for layer in self.layers[st:n_q]:
370
+ indices = layer.encode(residual)
371
+ quantized = layer.decode(indices)
372
+ residual = residual - quantized
373
+ all_indices.append(indices)
374
+ out_indices = torch.stack(all_indices)
375
+ return out_indices
376
+
377
+ def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
378
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
379
+ for i, indices in enumerate(q_indices):
380
+ layer = self.layers[st + i]
381
+ quantized = layer.decode(indices)
382
+ quantized_out = quantized_out + quantized
383
+ return quantized_out
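
A usage sketch for the residual quantizer above; the dimension, codebook size, and sequence length are arbitrary here, and a training-mode forward pass also updates the EMA codebook buffers.

import torch
from module.core_vq import ResidualVectorQuantization

rvq = ResidualVectorQuantization(num_quantizers=2, dim=64, codebook_size=256, kmeans_init=False)
x = torch.randn(4, 64, 50)                       # [batch, dim, time]
quantized, indices, losses, _ = rvq(x)           # indices: [n_q, batch, time]
codes = rvq.encode(x, n_q=2)                     # same layout as indices
recon = rvq.decode(codes)                        # [batch, dim, time], sum of per-stage decodes
print(quantized.shape, indices.shape, recon.shape)
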
module/data_utils.py ADDED
@@ -0,0 +1,332 @@
1
+ import time
2
+ import logging
3
+ import os
4
+ import random
5
+ import traceback
6
+ import numpy as np
7
+ import torch
8
+ import torch.utils.data
9
+ from tqdm import tqdm
10
+
11
+ from module import commons
12
+ from module.mel_processing import spectrogram_torch
13
+ from text import cleaned_text_to_sequence
14
+ from utils import load_wav_to_torch, load_filepaths_and_text
15
+ import torch.nn.functional as F
16
+ from functools import lru_cache
17
+ import requests
18
+ from scipy.io import wavfile
19
+ from io import BytesIO
20
+ from tools.my_utils import load_audio
21
+ version = os.environ.get('version',None)
22
+ # ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79)
23
+ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
24
+ """
25
+ 1) loads audio, speaker_id, text pairs
26
+ 2) normalizes text and converts them to sequences of integers
27
+ 3) computes spectrograms from audio files.
28
+ """
29
+
30
+ def __init__(self, hparams, val=False):
31
+ exp_dir = hparams.exp_dir
32
+ self.path2 = "%s/2-name2text.txt" % exp_dir
33
+ self.path4 = "%s/4-cnhubert" % exp_dir
34
+ self.path5 = "%s/5-wav32k" % exp_dir
35
+ assert os.path.exists(self.path2)
36
+ assert os.path.exists(self.path4)
37
+ assert os.path.exists(self.path5)
38
+ names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # strip the .pt suffix
39
+ names5 = set(os.listdir(self.path5))
40
+ self.phoneme_data = {}
41
+ with open(self.path2, "r", encoding="utf8") as f:
42
+ lines = f.read().strip("\n").split("\n")
43
+
44
+ for line in lines:
45
+ tmp = line.split("\t")
46
+ if (len(tmp) != 4):
47
+ continue
48
+ self.phoneme_data[tmp[0]] = [tmp[1]]
49
+
50
+ self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
51
+ tmp = self.audiopaths_sid_text
52
+ leng = len(tmp)
53
+ min_num = 100
54
+ if (leng < min_num):
55
+ self.audiopaths_sid_text = []
56
+ for _ in range(max(2, int(min_num / leng))):
57
+ self.audiopaths_sid_text += tmp
58
+ self.max_wav_value = hparams.max_wav_value
59
+ self.sampling_rate = hparams.sampling_rate
60
+ self.filter_length = hparams.filter_length
61
+ self.hop_length = hparams.hop_length
62
+ self.win_length = hparams.win_length
63
+ self.sampling_rate = hparams.sampling_rate
64
+ self.val = val
65
+
66
+ random.seed(1234)
67
+ random.shuffle(self.audiopaths_sid_text)
68
+
69
+ print("phoneme_data_len:", len(self.phoneme_data.keys()))
70
+ print("wav_data_len:", len(self.audiopaths_sid_text))
71
+
72
+ audiopaths_sid_text_new = []
73
+ lengths = []
74
+ skipped_phone = 0
75
+ skipped_dur = 0
76
+ for audiopath in tqdm(self.audiopaths_sid_text):
77
+ try:
78
+ phoneme = self.phoneme_data[audiopath][0]
79
+ phoneme = phoneme.split(' ')
80
+ phoneme_ids = cleaned_text_to_sequence(phoneme, version)
81
+ except Exception:
82
+ print(f"{audiopath} not in self.phoneme_data !")
83
+ skipped_phone += 1
84
+ continue
85
+
86
+ size = os.path.getsize("%s/%s" % (self.path5, audiopath))
87
+ duration = size / self.sampling_rate / 2
88
+
89
+ if duration == 0:
90
+ print(f"Zero duration for {audiopath}, skipping...")
91
+ skipped_dur += 1
92
+ continue
93
+
94
+ if 54 > duration > 0.6 or self.val:
95
+ audiopaths_sid_text_new.append([audiopath, phoneme_ids])
96
+ lengths.append(size // (2 * self.hop_length))
97
+ else:
98
+ skipped_dur += 1
99
+ continue
100
+
101
+ print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
102
+ print("total left: ", len(audiopaths_sid_text_new))
103
+ assert len(audiopaths_sid_text_new) > 1 # need at least enough samples to fill a batch; TODO
104
+ self.audiopaths_sid_text = audiopaths_sid_text_new
105
+ self.lengths = lengths
106
+
107
+ def get_audio_text_speaker_pair(self, audiopath_sid_text):
108
+ audiopath, phoneme_ids = audiopath_sid_text
109
+ text = torch.FloatTensor(phoneme_ids)
110
+ try:
111
+ spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
112
+ with torch.no_grad():
113
+ ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
114
+ if (ssl.shape[-1] != spec.shape[-1]):
115
+ typee = ssl.dtype
116
+ ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
117
+ ssl.requires_grad = False
118
+ except:
119
+ traceback.print_exc()
120
+ spec = torch.zeros(1025, 100)
121
+ wav = torch.zeros(1, 100 * self.hop_length)
122
+ ssl = torch.zeros(1, 768, 100)
123
+ text = text[-1:]
124
+ print("load audio or ssl error!!!!!!", audiopath)
125
+ return (ssl, spec, wav, text)
126
+
127
+ def get_audio(self, filename):
128
+ audio_array = load_audio(filename, self.sampling_rate) # load_audio already normalizes to [-1, 1], so no extra /32768 is needed
129
+ audio = torch.FloatTensor(audio_array) # /32768
130
+ audio_norm = audio
131
+ audio_norm = audio_norm.unsqueeze(0)
132
+ spec = spectrogram_torch(audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length,
133
+ center=False)
134
+ spec = torch.squeeze(spec, 0)
135
+ return spec, audio_norm
136
+
137
+ def get_sid(self, sid):
138
+ sid = torch.LongTensor([int(sid)])
139
+ return sid
140
+
141
+ def __getitem__(self, index):
142
+ # with torch.no_grad():
143
+ return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
144
+
145
+ def __len__(self):
146
+ return len(self.audiopaths_sid_text)
147
+
148
+ def random_slice(self, ssl, wav, mel):
149
+ assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
150
+ "first", ssl.shape, wav.shape)
151
+
152
+ len_mel = mel.shape[1]
153
+ if self.val:
154
+ reference_mel = mel[:, :len_mel // 3]
155
+ return reference_mel, ssl, wav, mel
156
+ dir = random.randint(0, 1)
157
+ sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2))
158
+
159
+ if dir == 0:
160
+ reference_mel = mel[:, :sep_point]
161
+ ssl = ssl[:, :, sep_point:]
162
+ wav2 = wav[:, sep_point * self.hop_length:]
163
+ mel = mel[:, sep_point:]
164
+ else:
165
+ reference_mel = mel[:, sep_point:]
166
+ ssl = ssl[:, :, :sep_point]
167
+ wav2 = wav[:, :sep_point * self.hop_length]
168
+ mel = mel[:, :sep_point]
169
+
170
+ assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
171
+ ssl.shape, wav.shape, wav2.shape, mel.shape, sep_point, self.hop_length, sep_point * self.hop_length, dir)
172
+ return reference_mel, ssl, wav2, mel
173
+
174
+
175
+ class TextAudioSpeakerCollate():
176
+ """ Zero-pads model inputs and targets
177
+ """
178
+
179
+ def __init__(self, return_ids=False):
180
+ self.return_ids = return_ids
181
+
182
+ def __call__(self, batch):
183
+ """Collates a training batch from normalized text, audio and speaker identities
184
+ PARAMS
185
+ ------
186
+ batch: [text_normalized, spec_normalized, wav_normalized, sid]
187
+ """
188
+ # Right zero-pad all one-hot text sequences to max input length
189
+ _, ids_sorted_decreasing = torch.sort(
190
+ torch.LongTensor([x[1].size(1) for x in batch]),
191
+ dim=0, descending=True)
192
+
193
+ max_ssl_len = max([x[0].size(2) for x in batch])
194
+ max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
195
+ max_spec_len = max([x[1].size(1) for x in batch])
196
+ max_spec_len = int(2 * ((max_spec_len // 2) + 1))
197
+ max_wav_len = max([x[2].size(1) for x in batch])
198
+ max_text_len = max([x[3].size(0) for x in batch])
199
+
200
+ ssl_lengths = torch.LongTensor(len(batch))
201
+ spec_lengths = torch.LongTensor(len(batch))
202
+ wav_lengths = torch.LongTensor(len(batch))
203
+ text_lengths = torch.LongTensor(len(batch))
204
+
205
+ spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
206
+ wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
207
+ ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
208
+ text_padded = torch.LongTensor(len(batch), max_text_len)
209
+
210
+ spec_padded.zero_()
211
+ wav_padded.zero_()
212
+ ssl_padded.zero_()
213
+ text_padded.zero_()
214
+
215
+ for i in range(len(ids_sorted_decreasing)):
216
+ row = batch[ids_sorted_decreasing[i]]
217
+
218
+ ssl = row[0]
219
+ ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
220
+ ssl_lengths[i] = ssl.size(2)
221
+
222
+ spec = row[1]
223
+ spec_padded[i, :, :spec.size(1)] = spec
224
+ spec_lengths[i] = spec.size(1)
225
+
226
+ wav = row[2]
227
+ wav_padded[i, :, :wav.size(1)] = wav
228
+ wav_lengths[i] = wav.size(1)
229
+
230
+ text = row[3]
231
+ text_padded[i, :text.size(0)] = text
232
+ text_lengths[i] = text.size(0)
233
+
234
+ return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
235
+
236
+
237
+ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
238
+ """
239
+ Maintain similar input lengths in a batch.
240
+ Length groups are specified by boundaries.
241
+ Ex) boundaries = [b1, b2, b3] -> any batch is included in either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.
242
+
243
+ It removes samples which are not included in the boundaries.
244
+ Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
245
+ """
246
+
247
+ def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
248
+ super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
249
+ self.lengths = dataset.lengths
250
+ self.batch_size = batch_size
251
+ self.boundaries = boundaries
252
+
253
+ self.buckets, self.num_samples_per_bucket = self._create_buckets()
254
+ self.total_size = sum(self.num_samples_per_bucket)
255
+ self.num_samples = self.total_size // self.num_replicas
256
+
257
+ def _create_buckets(self):
258
+ buckets = [[] for _ in range(len(self.boundaries) - 1)]
259
+ for i in range(len(self.lengths)):
260
+ length = self.lengths[i]
261
+ idx_bucket = self._bisect(length)
262
+ if idx_bucket != -1:
263
+ buckets[idx_bucket].append(i)
264
+
265
+ i = len(buckets) - 1
266
+ while i >= 0:
267
+ if len(buckets[i]) == 0:
268
+ buckets.pop(i)
269
+ self.boundaries.pop(i + 1)
270
+ i -= 1
271
+
272
+ num_samples_per_bucket = []
273
+ for i in range(len(buckets)):
274
+ len_bucket = len(buckets[i])
275
+ total_batch_size = self.num_replicas * self.batch_size
276
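+ # pad each bucket up to a multiple of num_replicas * batch_size so every rank gets whole batches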
+ rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
277
+ num_samples_per_bucket.append(len_bucket + rem)
278
+ return buckets, num_samples_per_bucket
279
+
280
+ def __iter__(self):
281
+ g = torch.Generator()
282
+ g.manual_seed(self.epoch)
283
+
284
+ indices = []
285
+ if self.shuffle:
286
+ for bucket in self.buckets:
287
+ indices.append(torch.randperm(len(bucket), generator=g).tolist())
288
+ else:
289
+ for bucket in self.buckets:
290
+ indices.append(list(range(len(bucket))))
291
+
292
+ batches = []
293
+ for i in range(len(self.buckets)):
294
+ bucket = self.buckets[i]
295
+ len_bucket = len(bucket)
296
+ ids_bucket = indices[i]
297
+ num_samples_bucket = self.num_samples_per_bucket[i]
298
+
299
+ rem = num_samples_bucket - len_bucket
300
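+ # repeat and wrap the shuffled indices so the bucket reaches its padded size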
+ ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)]
301
+
302
+ ids_bucket = ids_bucket[self.rank::self.num_replicas]
303
+
304
+ for j in range(len(ids_bucket) // self.batch_size):
305
+ batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]]
306
+ batches.append(batch)
307
+
308
+ if self.shuffle:
309
+ batch_ids = torch.randperm(len(batches), generator=g).tolist()
310
+ batches = [batches[i] for i in batch_ids]
311
+ self.batches = batches
312
+
313
+ assert len(self.batches) * self.batch_size == self.num_samples
314
+ return iter(self.batches)
315
+
316
+ def _bisect(self, x, lo=0, hi=None):
317
+ if hi is None:
318
+ hi = len(self.boundaries) - 1
319
+
320
+ if hi > lo:
321
+ mid = (hi + lo) // 2
322
+ if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
323
+ return mid
324
+ elif x <= self.boundaries[mid]:
325
+ return self._bisect(x, lo, mid)
326
+ else:
327
+ return self._bisect(x, mid + 1, hi)
328
+ else:
329
+ return -1
330
+
331
+ def __len__(self):
332
+ return self.num_samples // self.batch_size
module/losses.py ADDED
@@ -0,0 +1,73 @@
1
+ import math
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+
7
+ def feature_loss(fmap_r, fmap_g):
8
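+ # L1 feature-matching loss between real (detached) and generated discriminator feature maps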
+ loss = 0
9
+ for dr, dg in zip(fmap_r, fmap_g):
10
+ for rl, gl in zip(dr, dg):
11
+ rl = rl.float().detach()
12
+ gl = gl.float()
13
+ loss += torch.mean(torch.abs(rl - gl))
14
+
15
+ return loss * 2
16
+
17
+
18
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
19
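+ # LSGAN discriminator objective: real outputs are pushed toward 1, generated outputs toward 0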
+ loss = 0
20
+ r_losses = []
21
+ g_losses = []
22
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
23
+ dr = dr.float()
24
+ dg = dg.float()
25
+ r_loss = torch.mean((1 - dr) ** 2)
26
+ g_loss = torch.mean(dg**2)
27
+ loss += r_loss + g_loss
28
+ r_losses.append(r_loss.item())
29
+ g_losses.append(g_loss.item())
30
+
31
+ return loss, r_losses, g_losses
32
+
33
+
34
+ def generator_loss(disc_outputs):
35
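+ # LSGAN generator objective: generated outputs are pushed toward 1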
+ loss = 0
36
+ gen_losses = []
37
+ for dg in disc_outputs:
38
+ dg = dg.float()
39
+ l = torch.mean((1 - dg) ** 2)
40
+ gen_losses.append(l)
41
+ loss += l
42
+
43
+ return loss, gen_losses
44
+
45
+
46
+ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
47
+ """
48
+ z_p, logs_q: [b, h, t_t]
49
+ m_p, logs_p: [b, h, t_t]
50
+ """
51
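+ # sampled KL between the posterior (sample z_p, log-std logs_q) and the prior N(m_p, exp(logs_p)), masked and averaged over valid frames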
+ z_p = z_p.float()
52
+ logs_q = logs_q.float()
53
+ m_p = m_p.float()
54
+ logs_p = logs_p.float()
55
+ z_mask = z_mask.float()
56
+
57
+ kl = logs_p - logs_q - 0.5
58
+ kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
59
+ kl = torch.sum(kl * z_mask)
60
+ l = kl / torch.sum(z_mask)
61
+ return l
62
+
63
+
64
+ def mle_loss(z, m, logs, logdet, mask):
65
+ l = torch.sum(logs) + 0.5 * torch.sum(
66
+ torch.exp(-2 * logs) * ((z - m) ** 2)
67
+ ) # neg normal likelihood w/o the constant term
68
+ l = l - torch.sum(logdet) # log jacobian determinant
69
+ l = l / torch.sum(
70
+ torch.ones_like(z) * mask
71
+ ) # averaging across batch, channel and time axes
72
+ l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term
73
+ return l
module/mel_processing.py ADDED
@@ -0,0 +1,153 @@
1
+ import math
2
+ import os
3
+ import random
4
+ import torch
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+ import torch.utils.data
8
+ import numpy as np
9
+ import librosa
10
+ import librosa.util as librosa_util
11
+ from librosa.util import normalize, pad_center, tiny
12
+ from scipy.signal import get_window
13
+ from scipy.io.wavfile import read
14
+ from librosa.filters import mel as librosa_mel_fn
15
+
16
+ MAX_WAV_VALUE = 32768.0
17
+
18
+
19
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
20
+ """
21
+ PARAMS
22
+ ------
23
+ C: compression factor
24
+ """
25
+ return torch.log(torch.clamp(x, min=clip_val) * C)
26
+
27
+
28
+ def dynamic_range_decompression_torch(x, C=1):
29
+ """
30
+ PARAMS
31
+ ------
32
+ C: compression factor used to compress
33
+ """
34
+ return torch.exp(x) / C
35
+
36
+
37
+ def spectral_normalize_torch(magnitudes):
38
+ output = dynamic_range_compression_torch(magnitudes)
39
+ return output
40
+
41
+
42
+ def spectral_de_normalize_torch(magnitudes):
43
+ output = dynamic_range_decompression_torch(magnitudes)
44
+ return output
45
+
46
+
47
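+ # module-level caches keyed by dtype/device (plus fmax or window size) so mel filterbanks and Hann windows are built once per configuration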
+ mel_basis = {}
48
+ hann_window = {}
49
+
50
+
51
+ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
52
+ if torch.min(y) < -1.0:
53
+ print("min value is ", torch.min(y))
54
+ if torch.max(y) > 1.0:
55
+ print("max value is ", torch.max(y))
56
+
57
+ global hann_window
58
+ dtype_device = str(y.dtype) + "_" + str(y.device)
59
+ wnsize_dtype_device = str(win_size) + "_" + dtype_device
60
+ if wnsize_dtype_device not in hann_window:
61
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
62
+ dtype=y.dtype, device=y.device
63
+ )
64
+
65
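+ # reflect-pad by (n_fft - hop_size) / 2 on each side because torch.stft is called with center=False below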
+ y = torch.nn.functional.pad(
66
+ y.unsqueeze(1),
67
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
68
+ mode="reflect",
69
+ )
70
+ y = y.squeeze(1)
71
+ spec = torch.stft(
72
+ y,
73
+ n_fft,
74
+ hop_length=hop_size,
75
+ win_length=win_size,
76
+ window=hann_window[wnsize_dtype_device],
77
+ center=center,
78
+ pad_mode="reflect",
79
+ normalized=False,
80
+ onesided=True,
81
+ return_complex=False,
82
+ )
83
+
84
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
85
+ return spec
86
+
87
+
88
+ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
89
+ global mel_basis
90
+ dtype_device = str(spec.dtype) + "_" + str(spec.device)
91
+ fmax_dtype_device = str(fmax) + "_" + dtype_device
92
+ if fmax_dtype_device not in mel_basis:
93
+ mel = librosa_mel_fn(
94
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
95
+ )
96
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
97
+ dtype=spec.dtype, device=spec.device
98
+ )
99
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
100
+ spec = spectral_normalize_torch(spec)
101
+ return spec
102
+
103
+
104
+ def mel_spectrogram_torch(
105
+ y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
106
+ ):
107
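+ # expects y in [-1, 1] with shape [B, T]; illustrative call (parameter values are examples only, not taken from any config here):
+ # mel = mel_spectrogram_torch(wav, n_fft=1024, num_mels=80, sampling_rate=22050, hop_size=256, win_size=1024, fmin=0, fmax=None)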
+ if torch.min(y) < -1.0:
108
+ print("min value is ", torch.min(y))
109
+ if torch.max(y) > 1.0:
110
+ print("max value is ", torch.max(y))
111
+
112
+ global mel_basis, hann_window
113
+ dtype_device = str(y.dtype) + "_" + str(y.device)
114
+ fmax_dtype_device = str(fmax) + "_" + dtype_device
115
+ wnsize_dtype_device = str(win_size) + "_" + dtype_device
116
+ if fmax_dtype_device not in mel_basis:
117
+ mel = librosa_mel_fn(
118
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
119
+ )
120
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
121
+ dtype=y.dtype, device=y.device
122
+ )
123
+ if wnsize_dtype_device not in hann_window:
124
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
125
+ dtype=y.dtype, device=y.device
126
+ )
127
+
128
+ y = torch.nn.functional.pad(
129
+ y.unsqueeze(1),
130
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
131
+ mode="reflect",
132
+ )
133
+ y = y.squeeze(1)
134
+
135
+ spec = torch.stft(
136
+ y,
137
+ n_fft,
138
+ hop_length=hop_size,
139
+ win_length=win_size,
140
+ window=hann_window[wnsize_dtype_device],
141
+ center=center,
142
+ pad_mode="reflect",
143
+ normalized=False,
144
+ onesided=True,
145
+ return_complex=False,
146
+ )
147
+
148
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
149
+
150
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
151
+ spec = spectral_normalize_torch(spec)
152
+
153
+ return spec
module/models.py ADDED
@@ -0,0 +1,1040 @@
1
+ import warnings
2
+ warnings.filterwarnings("ignore")
3
+ import copy
4
+ import math
5
+ import os
6
+ import pdb
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+
12
+ from module import commons
13
+ from module import modules
14
+ from module import attentions
15
+
16
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
17
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
18
+ from module.commons import init_weights, get_padding
19
+ from module.mrte_model import MRTE
20
+ from module.quantize import ResidualVectorQuantizer
21
+ # from text import symbols
22
+ from text import symbols as symbols_v1
23
+ from text import symbols2 as symbols_v2
24
+ from torch.cuda.amp import autocast
25
+ import contextlib
26
+
27
+
28
+ class StochasticDurationPredictor(nn.Module):
29
+ def __init__(
30
+ self,
31
+ in_channels,
32
+ filter_channels,
33
+ kernel_size,
34
+ p_dropout,
35
+ n_flows=4,
36
+ gin_channels=0,
37
+ ):
38
+ super().__init__()
39
+ filter_channels = in_channels # it needs to be removed from future version.
40
+ self.in_channels = in_channels
41
+ self.filter_channels = filter_channels
42
+ self.kernel_size = kernel_size
43
+ self.p_dropout = p_dropout
44
+ self.n_flows = n_flows
45
+ self.gin_channels = gin_channels
46
+
47
+ self.log_flow = modules.Log()
48
+ self.flows = nn.ModuleList()
49
+ self.flows.append(modules.ElementwiseAffine(2))
50
+ for i in range(n_flows):
51
+ self.flows.append(
52
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
53
+ )
54
+ self.flows.append(modules.Flip())
55
+
56
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
57
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
58
+ self.post_convs = modules.DDSConv(
59
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
60
+ )
61
+ self.post_flows = nn.ModuleList()
62
+ self.post_flows.append(modules.ElementwiseAffine(2))
63
+ for i in range(4):
64
+ self.post_flows.append(
65
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
66
+ )
67
+ self.post_flows.append(modules.Flip())
68
+
69
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
70
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
71
+ self.convs = modules.DDSConv(
72
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
73
+ )
74
+ if gin_channels != 0:
75
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
76
+
77
+ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
78
+ x = torch.detach(x)
79
+ x = self.pre(x)
80
+ if g is not None:
81
+ g = torch.detach(g)
82
+ x = x + self.cond(g)
83
+ x = self.convs(x, x_mask)
84
+ x = self.proj(x) * x_mask
85
+
86
+ if not reverse:
87
+ flows = self.flows
88
+ assert w is not None
89
+
90
+ logdet_tot_q = 0
91
+ h_w = self.post_pre(w)
92
+ h_w = self.post_convs(h_w, x_mask)
93
+ h_w = self.post_proj(h_w) * x_mask
94
+ e_q = (
95
+ torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
96
+ * x_mask
97
+ )
98
+ z_q = e_q
99
+ for flow in self.post_flows:
100
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
101
+ logdet_tot_q += logdet_q
102
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
103
+ u = torch.sigmoid(z_u) * x_mask
104
+ z0 = (w - u) * x_mask
105
+ logdet_tot_q += torch.sum(
106
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
107
+ )
108
+ logq = (
109
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
110
+ - logdet_tot_q
111
+ )
112
+
113
+ logdet_tot = 0
114
+ z0, logdet = self.log_flow(z0, x_mask)
115
+ logdet_tot += logdet
116
+ z = torch.cat([z0, z1], 1)
117
+ for flow in flows:
118
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
119
+ logdet_tot = logdet_tot + logdet
120
+ nll = (
121
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
122
+ - logdet_tot
123
+ )
124
+ return nll + logq # [b]
125
+ else:
126
+ flows = list(reversed(self.flows))
127
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
128
+ z = (
129
+ torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
130
+ * noise_scale
131
+ )
132
+ for flow in flows:
133
+ z = flow(z, x_mask, g=x, reverse=reverse)
134
+ z0, z1 = torch.split(z, [1, 1], 1)
135
+ logw = z0
136
+ return logw
137
+
138
+
139
+ class DurationPredictor(nn.Module):
140
+ def __init__(
141
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
142
+ ):
143
+ super().__init__()
144
+
145
+ self.in_channels = in_channels
146
+ self.filter_channels = filter_channels
147
+ self.kernel_size = kernel_size
148
+ self.p_dropout = p_dropout
149
+ self.gin_channels = gin_channels
150
+
151
+ self.drop = nn.Dropout(p_dropout)
152
+ self.conv_1 = nn.Conv1d(
153
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
154
+ )
155
+ self.norm_1 = modules.LayerNorm(filter_channels)
156
+ self.conv_2 = nn.Conv1d(
157
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
158
+ )
159
+ self.norm_2 = modules.LayerNorm(filter_channels)
160
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
161
+
162
+ if gin_channels != 0:
163
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
164
+
165
+ def forward(self, x, x_mask, g=None):
166
+ x = torch.detach(x)
167
+ if g is not None:
168
+ g = torch.detach(g)
169
+ x = x + self.cond(g)
170
+ x = self.conv_1(x * x_mask)
171
+ x = torch.relu(x)
172
+ x = self.norm_1(x)
173
+ x = self.drop(x)
174
+ x = self.conv_2(x * x_mask)
175
+ x = torch.relu(x)
176
+ x = self.norm_2(x)
177
+ x = self.drop(x)
178
+ x = self.proj(x * x_mask)
179
+ return x * x_mask
180
+
181
+
182
+ class TextEncoder(nn.Module):
183
+ def __init__(
184
+ self,
185
+ out_channels,
186
+ hidden_channels,
187
+ filter_channels,
188
+ n_heads,
189
+ n_layers,
190
+ kernel_size,
191
+ p_dropout,
192
+ latent_channels=192,
193
+ version = "v2",
194
+ ):
195
+ super().__init__()
196
+ self.out_channels = out_channels
197
+ self.hidden_channels = hidden_channels
198
+ self.filter_channels = filter_channels
199
+ self.n_heads = n_heads
200
+ self.n_layers = n_layers
201
+ self.kernel_size = kernel_size
202
+ self.p_dropout = p_dropout
203
+ self.latent_channels = latent_channels
204
+ self.version = version
205
+
206
+ self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
207
+
208
+ self.encoder_ssl = attentions.Encoder(
209
+ hidden_channels,
210
+ filter_channels,
211
+ n_heads,
212
+ n_layers // 2,
213
+ kernel_size,
214
+ p_dropout,
215
+ )
216
+
217
+ self.encoder_text = attentions.Encoder(
218
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
219
+ )
220
+
221
+ if self.version == "v1":
222
+ symbols = symbols_v1.symbols
223
+ else:
224
+ symbols = symbols_v2.symbols
225
+ self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
226
+
227
+ self.mrte = MRTE()
228
+
229
+ self.encoder2 = attentions.Encoder(
230
+ hidden_channels,
231
+ filter_channels,
232
+ n_heads,
233
+ n_layers // 2,
234
+ kernel_size,
235
+ p_dropout,
236
+ )
237
+
238
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
239
+
240
+ def forward(self, y, y_lengths, text, text_lengths, ge, speed=1,test=None):
241
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
242
+ y.dtype
243
+ )
244
+
245
+ y = self.ssl_proj(y * y_mask) * y_mask
246
+
247
+ y = self.encoder_ssl(y * y_mask, y_mask)
248
+
249
+ text_mask = torch.unsqueeze(
250
+ commons.sequence_mask(text_lengths, text.size(1)), 1
251
+ ).to(y.dtype)
252
+ if test == 1:
253
+ text[:, :] = 0
254
+ text = self.text_embedding(text).transpose(1, 2)
255
+ text = self.encoder_text(text * text_mask, text_mask)
256
+ y = self.mrte(y, y_mask, text, text_mask, ge)
257
+ y = self.encoder2(y * y_mask, y_mask)
258
+ if(speed!=1):
259
+ y = F.interpolate(y, size=int(y.shape[-1] / speed)+1, mode="linear")
260
+ y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
261
+ stats = self.proj(y) * y_mask
262
+ m, logs = torch.split(stats, self.out_channels, dim=1)
263
+ return y, m, logs, y_mask
264
+
265
+ def extract_latent(self, x):
266
+ x = self.ssl_proj(x)
267
+ quantized, codes, commit_loss, quantized_list = self.quantizer(x)
268
+ return codes.transpose(0, 1)
269
+
270
+ def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
271
+ quantized = self.quantizer.decode(codes)
272
+
273
+ y = self.vq_proj(quantized) * y_mask
274
+ y = self.encoder_ssl(y * y_mask, y_mask)
275
+
276
+ y = self.mrte(y, y_mask, refer, refer_mask, ge)
277
+
278
+ y = self.encoder2(y * y_mask, y_mask)
279
+
280
+ stats = self.proj(y) * y_mask
281
+ m, logs = torch.split(stats, self.out_channels, dim=1)
282
+ return y, m, logs, y_mask, quantized
283
+
284
+
285
+ class ResidualCouplingBlock(nn.Module):
286
+ def __init__(
287
+ self,
288
+ channels,
289
+ hidden_channels,
290
+ kernel_size,
291
+ dilation_rate,
292
+ n_layers,
293
+ n_flows=4,
294
+ gin_channels=0,
295
+ ):
296
+ super().__init__()
297
+ self.channels = channels
298
+ self.hidden_channels = hidden_channels
299
+ self.kernel_size = kernel_size
300
+ self.dilation_rate = dilation_rate
301
+ self.n_layers = n_layers
302
+ self.n_flows = n_flows
303
+ self.gin_channels = gin_channels
304
+
305
+ self.flows = nn.ModuleList()
306
+ for i in range(n_flows):
307
+ self.flows.append(
308
+ modules.ResidualCouplingLayer(
309
+ channels,
310
+ hidden_channels,
311
+ kernel_size,
312
+ dilation_rate,
313
+ n_layers,
314
+ gin_channels=gin_channels,
315
+ mean_only=True,
316
+ )
317
+ )
318
+ self.flows.append(modules.Flip())
319
+
320
+ def forward(self, x, x_mask, g=None, reverse=False):
321
+ if not reverse:
322
+ for flow in self.flows:
323
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
324
+ else:
325
+ for flow in reversed(self.flows):
326
+ x = flow(x, x_mask, g=g, reverse=reverse)
327
+ return x
328
+
329
+
330
+ class PosteriorEncoder(nn.Module):
331
+ def __init__(
332
+ self,
333
+ in_channels,
334
+ out_channels,
335
+ hidden_channels,
336
+ kernel_size,
337
+ dilation_rate,
338
+ n_layers,
339
+ gin_channels=0,
340
+ ):
341
+ super().__init__()
342
+ self.in_channels = in_channels
343
+ self.out_channels = out_channels
344
+ self.hidden_channels = hidden_channels
345
+ self.kernel_size = kernel_size
346
+ self.dilation_rate = dilation_rate
347
+ self.n_layers = n_layers
348
+ self.gin_channels = gin_channels
349
+
350
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
351
+ self.enc = modules.WN(
352
+ hidden_channels,
353
+ kernel_size,
354
+ dilation_rate,
355
+ n_layers,
356
+ gin_channels=gin_channels,
357
+ )
358
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
359
+
360
+ def forward(self, x, x_lengths, g=None):
361
+ if g != None:
362
+ g = g.detach()
363
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
364
+ x.dtype
365
+ )
366
+ x = self.pre(x) * x_mask
367
+ x = self.enc(x, x_mask, g=g)
368
+ stats = self.proj(x) * x_mask
369
+ m, logs = torch.split(stats, self.out_channels, dim=1)
370
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
371
+ return z, m, logs, x_mask
372
+
373
+
374
+ class WNEncoder(nn.Module):
375
+ def __init__(
376
+ self,
377
+ in_channels,
378
+ out_channels,
379
+ hidden_channels,
380
+ kernel_size,
381
+ dilation_rate,
382
+ n_layers,
383
+ gin_channels=0,
384
+ ):
385
+ super().__init__()
386
+ self.in_channels = in_channels
387
+ self.out_channels = out_channels
388
+ self.hidden_channels = hidden_channels
389
+ self.kernel_size = kernel_size
390
+ self.dilation_rate = dilation_rate
391
+ self.n_layers = n_layers
392
+ self.gin_channels = gin_channels
393
+
394
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
395
+ self.enc = modules.WN(
396
+ hidden_channels,
397
+ kernel_size,
398
+ dilation_rate,
399
+ n_layers,
400
+ gin_channels=gin_channels,
401
+ )
402
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
403
+ self.norm = modules.LayerNorm(out_channels)
404
+
405
+ def forward(self, x, x_lengths, g=None):
406
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
407
+ x.dtype
408
+ )
409
+ x = self.pre(x) * x_mask
410
+ x = self.enc(x, x_mask, g=g)
411
+ out = self.proj(x) * x_mask
412
+ out = self.norm(out)
413
+ return out
414
+
415
+
416
+ class Generator(torch.nn.Module):
417
+ def __init__(
418
+ self,
419
+ initial_channel,
420
+ resblock,
421
+ resblock_kernel_sizes,
422
+ resblock_dilation_sizes,
423
+ upsample_rates,
424
+ upsample_initial_channel,
425
+ upsample_kernel_sizes,
426
+ gin_channels=0,
427
+ ):
428
+ super(Generator, self).__init__()
429
+ self.num_kernels = len(resblock_kernel_sizes)
430
+ self.num_upsamples = len(upsample_rates)
431
+ self.conv_pre = Conv1d(
432
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
433
+ )
434
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
435
+
436
+ self.ups = nn.ModuleList()
437
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
438
+ self.ups.append(
439
+ weight_norm(
440
+ ConvTranspose1d(
441
+ upsample_initial_channel // (2**i),
442
+ upsample_initial_channel // (2 ** (i + 1)),
443
+ k,
444
+ u,
445
+ padding=(k - u) // 2,
446
+ )
447
+ )
448
+ )
449
+
450
+ self.resblocks = nn.ModuleList()
451
+ for i in range(len(self.ups)):
452
+ ch = upsample_initial_channel // (2 ** (i + 1))
453
+ for j, (k, d) in enumerate(
454
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
455
+ ):
456
+ self.resblocks.append(resblock(ch, k, d))
457
+
458
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
459
+ self.ups.apply(init_weights)
460
+
461
+ if gin_channels != 0:
462
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
463
+
464
+ def forward(self, x, g=None):
465
+ x = self.conv_pre(x)
466
+ if g is not None:
467
+ x = x + self.cond(g)
468
+
469
+ for i in range(self.num_upsamples):
470
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
471
+ x = self.ups[i](x)
472
+ xs = None
473
+ for j in range(self.num_kernels):
474
+ if xs is None:
475
+ xs = self.resblocks[i * self.num_kernels + j](x)
476
+ else:
477
+ xs += self.resblocks[i * self.num_kernels + j](x)
478
+ x = xs / self.num_kernels
479
+ x = F.leaky_relu(x)
480
+ x = self.conv_post(x)
481
+ x = torch.tanh(x)
482
+
483
+ return x
484
+
485
+ def remove_weight_norm(self):
486
+ print("Removing weight norm...")
487
+ for l in self.ups:
488
+ remove_weight_norm(l)
489
+ for l in self.resblocks:
490
+ l.remove_weight_norm()
491
+
492
+
493
+ class DiscriminatorP(torch.nn.Module):
494
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
495
+ super(DiscriminatorP, self).__init__()
496
+ self.period = period
497
+ self.use_spectral_norm = use_spectral_norm
498
+ norm_f = spectral_norm if use_spectral_norm else weight_norm
499
+ self.convs = nn.ModuleList(
500
+ [
501
+ norm_f(
502
+ Conv2d(
503
+ 1,
504
+ 32,
505
+ (kernel_size, 1),
506
+ (stride, 1),
507
+ padding=(get_padding(kernel_size, 1), 0),
508
+ )
509
+ ),
510
+ norm_f(
511
+ Conv2d(
512
+ 32,
513
+ 128,
514
+ (kernel_size, 1),
515
+ (stride, 1),
516
+ padding=(get_padding(kernel_size, 1), 0),
517
+ )
518
+ ),
519
+ norm_f(
520
+ Conv2d(
521
+ 128,
522
+ 512,
523
+ (kernel_size, 1),
524
+ (stride, 1),
525
+ padding=(get_padding(kernel_size, 1), 0),
526
+ )
527
+ ),
528
+ norm_f(
529
+ Conv2d(
530
+ 512,
531
+ 1024,
532
+ (kernel_size, 1),
533
+ (stride, 1),
534
+ padding=(get_padding(kernel_size, 1), 0),
535
+ )
536
+ ),
537
+ norm_f(
538
+ Conv2d(
539
+ 1024,
540
+ 1024,
541
+ (kernel_size, 1),
542
+ 1,
543
+ padding=(get_padding(kernel_size, 1), 0),
544
+ )
545
+ ),
546
+ ]
547
+ )
548
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
549
+
550
+ def forward(self, x):
551
+ fmap = []
552
+
553
+ # 1d to 2d
554
+ b, c, t = x.shape
555
+ if t % self.period != 0: # pad first
556
+ n_pad = self.period - (t % self.period)
557
+ x = F.pad(x, (0, n_pad), "reflect")
558
+ t = t + n_pad
559
+ x = x.view(b, c, t // self.period, self.period)
560
+
561
+ for l in self.convs:
562
+ x = l(x)
563
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
564
+ fmap.append(x)
565
+ x = self.conv_post(x)
566
+ fmap.append(x)
567
+ x = torch.flatten(x, 1, -1)
568
+
569
+ return x, fmap
570
+
571
+
572
+ class DiscriminatorS(torch.nn.Module):
573
+ def __init__(self, use_spectral_norm=False):
574
+ super(DiscriminatorS, self).__init__()
575
+ norm_f = spectral_norm if use_spectral_norm else weight_norm
576
+ self.convs = nn.ModuleList(
577
+ [
578
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
579
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
580
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
581
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
582
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
583
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
584
+ ]
585
+ )
586
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
587
+
588
+ def forward(self, x):
589
+ fmap = []
590
+
591
+ for l in self.convs:
592
+ x = l(x)
593
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
594
+ fmap.append(x)
595
+ x = self.conv_post(x)
596
+ fmap.append(x)
597
+ x = torch.flatten(x, 1, -1)
598
+
599
+ return x, fmap
600
+
601
+
602
+ class MultiPeriodDiscriminator(torch.nn.Module):
603
+ def __init__(self, use_spectral_norm=False):
604
+ super(MultiPeriodDiscriminator, self).__init__()
605
+ periods = [2, 3, 5, 7, 11]
606
+
607
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
608
+ discs = discs + [
609
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
610
+ ]
611
+ self.discriminators = nn.ModuleList(discs)
612
+
613
+ def forward(self, y, y_hat):
614
+ y_d_rs = []
615
+ y_d_gs = []
616
+ fmap_rs = []
617
+ fmap_gs = []
618
+ for i, d in enumerate(self.discriminators):
619
+ y_d_r, fmap_r = d(y)
620
+ y_d_g, fmap_g = d(y_hat)
621
+ y_d_rs.append(y_d_r)
622
+ y_d_gs.append(y_d_g)
623
+ fmap_rs.append(fmap_r)
624
+ fmap_gs.append(fmap_g)
625
+
626
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
627
+
628
+
629
+ class ReferenceEncoder(nn.Module):
630
+ """
631
+ inputs --- [N, Ty/r, n_mels*r] mels
632
+ outputs --- [N, ref_enc_gru_size]
633
+ """
634
+
635
+ def __init__(self, spec_channels, gin_channels=0):
636
+ super().__init__()
637
+ self.spec_channels = spec_channels
638
+ ref_enc_filters = [32, 32, 64, 64, 128, 128]
639
+ K = len(ref_enc_filters)
640
+ filters = [1] + ref_enc_filters
641
+ convs = [
642
+ weight_norm(
643
+ nn.Conv2d(
644
+ in_channels=filters[i],
645
+ out_channels=filters[i + 1],
646
+ kernel_size=(3, 3),
647
+ stride=(2, 2),
648
+ padding=(1, 1),
649
+ )
650
+ )
651
+ for i in range(K)
652
+ ]
653
+ self.convs = nn.ModuleList(convs)
654
+ # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])
655
+
656
+ out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
657
+ self.gru = nn.GRU(
658
+ input_size=ref_enc_filters[-1] * out_channels,
659
+ hidden_size=256 // 2,
660
+ batch_first=True,
661
+ )
662
+ self.proj = nn.Linear(128, gin_channels)
663
+
664
+ def forward(self, inputs):
665
+ N = inputs.size(0)
666
+ out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
667
+ for conv in self.convs:
668
+ out = conv(out)
669
+ # out = wn(out)
670
+ out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
671
+
672
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
673
+ T = out.size(1)
674
+ N = out.size(0)
675
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
676
+
677
+ self.gru.flatten_parameters()
678
+ memory, out = self.gru(out) # out --- [1, N, 128]
679
+
680
+ return self.proj(out.squeeze(0)).unsqueeze(-1)
681
+
682
+ def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
683
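+ # apply the conv output-length formula n_convs times to get the remaining frequency dimension after the stride-2 convs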
+ for i in range(n_convs):
684
+ L = (L - kernel_size + 2 * pad) // stride + 1
685
+ return L
686
+
687
+
688
+ class Quantizer_module(torch.nn.Module):
689
+ def __init__(self, n_e, e_dim):
690
+ super(Quantizer_module, self).__init__()
691
+ self.embedding = nn.Embedding(n_e, e_dim)
692
+ self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)
693
+
694
+ def forward(self, x):
695
+ d = (
696
+ torch.sum(x**2, 1, keepdim=True)
697
+ + torch.sum(self.embedding.weight**2, 1)
698
+ - 2 * torch.matmul(x, self.embedding.weight.T)
699
+ )
700
+ min_indicies = torch.argmin(d, 1)
701
+ z_q = self.embedding(min_indicies)
702
+ return z_q, min_indicies
703
+
704
+
705
+ class Quantizer(torch.nn.Module):
706
+ def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160):
707
+ super(Quantizer, self).__init__()
708
+ assert embed_dim % n_code_groups == 0
709
+ self.quantizer_modules = nn.ModuleList(
710
+ [
711
+ Quantizer_module(n_codes, embed_dim // n_code_groups)
712
+ for _ in range(n_code_groups)
713
+ ]
714
+ )
715
+ self.n_code_groups = n_code_groups
716
+ self.embed_dim = embed_dim
717
+
718
+ def forward(self, xin):
719
+ # B, C, T
720
+ B, C, T = xin.shape
721
+ xin = xin.transpose(1, 2)
722
+ x = xin.reshape(-1, self.embed_dim)
723
+ x = torch.split(x, self.embed_dim // self.n_code_groups, dim=-1)
724
+ min_indicies = []
725
+ z_q = []
726
+ for _x, m in zip(x, self.quantizer_modules):
727
+ _z_q, _min_indicies = m(_x)
728
+ z_q.append(_z_q)
729
+ min_indicies.append(_min_indicies) # B * T,
730
+ z_q = torch.cat(z_q, -1).reshape(xin.shape)
731
+ loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
732
+ (z_q - xin.detach()) ** 2
733
+ )
734
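+ # straight-through estimator: the forward pass uses the quantized values while gradients flow back to xin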
+ z_q = xin + (z_q - xin).detach()
735
+ z_q = z_q.transpose(1, 2)
736
+ codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
737
+ return z_q, loss, codes.transpose(1, 2)
738
+
739
+ def embed(self, x):
740
+ # idx: N, 4, T
741
+ x = x.transpose(1, 2)
742
+ x = torch.split(x, 1, 2)
743
+ ret = []
744
+ for q, embed in zip(x, self.quantizer_modules):
745
+ q = embed.embedding(q.squeeze(-1))
746
+ ret.append(q)
747
+ ret = torch.cat(ret, -1)
748
+ return ret.transpose(1, 2) # N, C, T
749
+
750
+
751
+ class CodePredictor(nn.Module):
752
+ def __init__(
753
+ self,
754
+ hidden_channels,
755
+ filter_channels,
756
+ n_heads,
757
+ n_layers,
758
+ kernel_size,
759
+ p_dropout,
760
+ n_q=8,
761
+ dims=1024,
762
+ ssl_dim=768,
763
+ ):
764
+ super().__init__()
765
+ self.hidden_channels = hidden_channels
766
+ self.filter_channels = filter_channels
767
+ self.n_heads = n_heads
768
+ self.n_layers = n_layers
769
+ self.kernel_size = kernel_size
770
+ self.p_dropout = p_dropout
771
+
772
+ self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
773
+ self.ref_enc = modules.MelStyleEncoder(
774
+ ssl_dim, style_vector_dim=hidden_channels
775
+ )
776
+
777
+ self.encoder = attentions.Encoder(
778
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
779
+ )
780
+
781
+ self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
782
+ self.n_q = n_q
783
+ self.dims = dims
784
+
785
+ def forward(self, x, x_mask, refer, codes, infer=False):
786
+ x = x.detach()
787
+ x = self.vq_proj(x * x_mask) * x_mask
788
+ g = self.ref_enc(refer, x_mask)
789
+ x = x + g
790
+ x = self.encoder(x * x_mask, x_mask)
791
+ x = self.out_proj(x * x_mask) * x_mask
792
+ logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
793
+ 2, 3
794
+ )
795
+ target = codes[1:].transpose(0, 1)
796
+ if not infer:
797
+ logits = logits.reshape(-1, self.dims)
798
+ target = target.reshape(-1)
799
+ loss = torch.nn.functional.cross_entropy(logits, target)
800
+ return loss
801
+ else:
802
+ _, top10_preds = torch.topk(logits, 10, dim=-1)
803
+ correct_top10 = torch.any(top10_preds == target.unsqueeze(-1), dim=-1)
804
+ top10_acc = 100 * torch.mean(correct_top10.float()).detach().cpu().item()
805
+
806
+ print("Top-10 Accuracy:", top10_acc, "%")
807
+
808
+ pred_codes = torch.argmax(logits, dim=-1)
809
+ acc = 100 * torch.mean((pred_codes == target).float()).detach().cpu().item()
810
+ print("Top-1 Accuracy:", acc, "%")
811
+
812
+ return pred_codes.transpose(0, 1)
813
+
814
+
815
+ class SynthesizerTrn(nn.Module):
816
+ """
817
+ Synthesizer for Training
818
+ """
819
+
820
+ def __init__(
821
+ self,
822
+ spec_channels,
823
+ segment_size,
824
+ inter_channels,
825
+ hidden_channels,
826
+ filter_channels,
827
+ n_heads,
828
+ n_layers,
829
+ kernel_size,
830
+ p_dropout,
831
+ resblock,
832
+ resblock_kernel_sizes,
833
+ resblock_dilation_sizes,
834
+ upsample_rates,
835
+ upsample_initial_channel,
836
+ upsample_kernel_sizes,
837
+ n_speakers=0,
838
+ gin_channels=0,
839
+ use_sdp=True,
840
+ semantic_frame_rate=None,
841
+ freeze_quantizer=None,
842
+ version = "v2",
843
+ **kwargs
844
+ ):
845
+ super().__init__()
846
+ self.spec_channels = spec_channels
847
+ self.inter_channels = inter_channels
848
+ self.hidden_channels = hidden_channels
849
+ self.filter_channels = filter_channels
850
+ self.n_heads = n_heads
851
+ self.n_layers = n_layers
852
+ self.kernel_size = kernel_size
853
+ self.p_dropout = p_dropout
854
+ self.resblock = resblock
855
+ self.resblock_kernel_sizes = resblock_kernel_sizes
856
+ self.resblock_dilation_sizes = resblock_dilation_sizes
857
+ self.upsample_rates = upsample_rates
858
+ self.upsample_initial_channel = upsample_initial_channel
859
+ self.upsample_kernel_sizes = upsample_kernel_sizes
860
+ self.segment_size = segment_size
861
+ self.n_speakers = n_speakers
862
+ self.gin_channels = gin_channels
863
+ self.version = version
864
+
865
+ self.use_sdp = use_sdp
866
+ self.enc_p = TextEncoder(
867
+ inter_channels,
868
+ hidden_channels,
869
+ filter_channels,
870
+ n_heads,
871
+ n_layers,
872
+ kernel_size,
873
+ p_dropout,
874
+ version = version,
875
+ )
876
+ self.dec = Generator(
877
+ inter_channels,
878
+ resblock,
879
+ resblock_kernel_sizes,
880
+ resblock_dilation_sizes,
881
+ upsample_rates,
882
+ upsample_initial_channel,
883
+ upsample_kernel_sizes,
884
+ gin_channels=gin_channels,
885
+ )
886
+ self.enc_q = PosteriorEncoder(
887
+ spec_channels,
888
+ inter_channels,
889
+ hidden_channels,
890
+ 5,
891
+ 1,
892
+ 16,
893
+ gin_channels=gin_channels,
894
+ )
895
+ self.flow = ResidualCouplingBlock(
896
+ inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
897
+ )
898
+
899
+ # self.version=os.environ.get("version","v1")
900
+ if(self.version=="v1"):
901
+ self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
902
+ else:
903
+ self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)
904
+
905
+ ssl_dim = 768
906
+ assert semantic_frame_rate in ["25hz", "50hz"]
907
+ self.semantic_frame_rate = semantic_frame_rate
908
+ if semantic_frame_rate == "25hz":
909
+ self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
910
+ else:
911
+ self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
912
+
913
+ self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
914
+ self.freeze_quantizer = freeze_quantizer
915
+ self.sv_emb = nn.Linear(20480, gin_channels)
916
+ self.ge_to512 = nn.Linear(gin_channels, 512)
917
+ self.prelu = nn.PReLU(num_parameters=gin_channels)
918
+
919
+ def forward(self, ssl, y, y_lengths, text, text_lengths, sv_emb=None):
920
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
921
+ y.dtype
922
+ )
923
+ if(self.version=="v1"):
924
+ ge = self.ref_enc(y * y_mask, y_mask)
925
+ else:
926
+ ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
927
+ sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
928
+ ge += sv_emb.unsqueeze(-1)
929
+ ge = self.prelu(ge)
930
+ ge512 = self.ge_to512(ge.transpose(2, 1)).transpose(2, 1)
931
+ with autocast(enabled=False):
932
+ maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
933
+ with maybe_no_grad:
934
+ if self.freeze_quantizer:
935
+ self.ssl_proj.eval()
936
+ self.quantizer.eval()
937
+ ssl = self.ssl_proj(ssl)
938
+ quantized, codes, commit_loss, quantized_list = self.quantizer(
939
+ ssl, layers=[0]
940
+ )
941
+
942
+ if self.semantic_frame_rate == "25hz":
943
+ quantized = F.interpolate(
944
+ quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
945
+ )
946
+
947
+ x, m_p, logs_p, y_mask = self.enc_p(
948
+ quantized, y_lengths, text, text_lengths, ge512
949
+ )
950
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
951
+ z_p = self.flow(z, y_mask, g=ge)
952
+
953
+ z_slice, ids_slice = commons.rand_slice_segments(
954
+ z, y_lengths, self.segment_size
955
+ )
956
+ o = self.dec(z_slice, g=ge)
957
+ return (
958
+ o,
959
+ commit_loss,
960
+ ids_slice,
961
+ y_mask,
962
+ y_mask,
963
+ (z, z_p, m_p, logs_p, m_q, logs_q),
964
+ quantized,
965
+ )
966
+
967
+ def infer(self, ssl, y, y_lengths, text, text_lengths, test=None, noise_scale=0.5):
968
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
969
+ y.dtype
970
+ )
971
+ if(self.version=="v1"):
972
+ ge = self.ref_enc(y * y_mask, y_mask)
973
+ else:
974
+ ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
975
+
976
+ ssl = self.ssl_proj(ssl)
977
+ quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
978
+ if self.semantic_frame_rate == "25hz":
979
+ quantized = F.interpolate(
980
+ quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
981
+ )
982
+
983
+ x, m_p, logs_p, y_mask = self.enc_p(
984
+ quantized, y_lengths, text, text_lengths, ge, test=test
985
+ )
986
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
987
+
988
+ z = self.flow(z_p, y_mask, g=ge, reverse=True)
989
+
990
+ o = self.dec((z * y_mask)[:, :, :], g=ge)
991
+ return o, y_mask, (z, z_p, m_p, logs_p)
992
+
993
+ @torch.no_grad()
994
+ def decode(self, codes, text, refer, noise_scale=0.5,speed=1, sv_emb=None):
995
+ def get_ge(refer, sv_emb):
996
+ ge = None
997
+ if refer is not None:
998
+ refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
999
+ refer_mask = torch.unsqueeze(
1000
+ commons.sequence_mask(refer_lengths, refer.size(2)), 1
1001
+ ).to(refer.dtype)
1002
+ if (self.version == "v1"):
1003
+ ge = self.ref_enc(refer * refer_mask, refer_mask)
1004
+ else:
1005
+ ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
1006
+ sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
1007
+ ge += sv_emb.unsqueeze(-1)
1008
+ ge = self.prelu(ge)
1009
+ return ge
1010
+ if(type(refer)==list):
1011
+ ges=[]
1012
+ for idx,_refer in enumerate(refer):
1013
+ ge=get_ge(_refer,sv_emb[idx])
1014
+ ges.append(ge)
1015
+ ge=torch.stack(ges,0).mean(0)
1016
+ else:
1017
+ ge = get_ge(refer, sv_emb)
1018
+
1019
+ y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
1020
+ text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
1021
+
1022
+ quantized = self.quantizer.decode(codes)
1023
+ if self.semantic_frame_rate == "25hz":
1024
+ quantized = F.interpolate(
1025
+ quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
1026
+ )
1027
+ x, m_p, logs_p, y_mask = self.enc_p(
1028
+ quantized, y_lengths, text, text_lengths, self.ge_to512(ge.transpose(2,1)).transpose(2,1),speed
1029
+ )
1030
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
1031
+
1032
+ z = self.flow(z_p, y_mask, g=ge, reverse=True)
1033
+
1034
+ o = self.dec((z * y_mask)[:, :, :], g=ge)
1035
+ return o
1036
+
1037
+ def extract_latent(self, x):
1038
+ ssl = self.ssl_proj(x)
1039
+ quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
1040
+ return codes.transpose(0, 1)
module/models_onnx.py ADDED
@@ -0,0 +1,918 @@
1
+ import copy
2
+ import math
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from module import commons
8
+ from module import modules
9
+ from module import attentions_onnx as attentions
10
+
11
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
12
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13
+ from module.commons import init_weights, get_padding
14
+ from module.mrte_model import MRTE
15
+ from module.quantize import ResidualVectorQuantizer
16
+ from text import symbols
17
+ from torch.cuda.amp import autocast
18
+
19
+
20
+ class StochasticDurationPredictor(nn.Module):
21
+ def __init__(
22
+ self,
23
+ in_channels,
24
+ filter_channels,
25
+ kernel_size,
26
+ p_dropout,
27
+ n_flows=4,
28
+ gin_channels=0,
29
+ ):
30
+ super().__init__()
31
+ filter_channels = in_channels # it needs to be removed from future version.
32
+ self.in_channels = in_channels
33
+ self.filter_channels = filter_channels
34
+ self.kernel_size = kernel_size
35
+ self.p_dropout = p_dropout
36
+ self.n_flows = n_flows
37
+ self.gin_channels = gin_channels
38
+
39
+ self.log_flow = modules.Log()
40
+ self.flows = nn.ModuleList()
41
+ self.flows.append(modules.ElementwiseAffine(2))
42
+ for i in range(n_flows):
43
+ self.flows.append(
44
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
45
+ )
46
+ self.flows.append(modules.Flip())
47
+
48
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
49
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
50
+ self.post_convs = modules.DDSConv(
51
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
52
+ )
53
+ self.post_flows = nn.ModuleList()
54
+ self.post_flows.append(modules.ElementwiseAffine(2))
55
+ for i in range(4):
56
+ self.post_flows.append(
57
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
58
+ )
59
+ self.post_flows.append(modules.Flip())
60
+
61
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
62
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
63
+ self.convs = modules.DDSConv(
64
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
65
+ )
66
+ if gin_channels != 0:
67
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
68
+
69
+ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
70
+ x = torch.detach(x)
71
+ x = self.pre(x)
72
+ if g is not None:
73
+ g = torch.detach(g)
74
+ x = x + self.cond(g)
75
+ x = self.convs(x, x_mask)
76
+ x = self.proj(x) * x_mask
77
+
78
+ if not reverse:
79
+ flows = self.flows
80
+ assert w is not None
81
+
82
+ logdet_tot_q = 0
83
+ h_w = self.post_pre(w)
84
+ h_w = self.post_convs(h_w, x_mask)
85
+ h_w = self.post_proj(h_w) * x_mask
86
+ e_q = (
87
+ torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
88
+ * x_mask
89
+ )
90
+ z_q = e_q
91
+ for flow in self.post_flows:
92
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
93
+ logdet_tot_q += logdet_q
94
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
95
+ u = torch.sigmoid(z_u) * x_mask
96
+ z0 = (w - u) * x_mask
97
+ logdet_tot_q += torch.sum(
98
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
99
+ )
100
+ logq = (
101
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
102
+ - logdet_tot_q
103
+ )
104
+
105
+ logdet_tot = 0
106
+ z0, logdet = self.log_flow(z0, x_mask)
107
+ logdet_tot += logdet
108
+ z = torch.cat([z0, z1], 1)
109
+ for flow in flows:
110
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
111
+ logdet_tot = logdet_tot + logdet
112
+ nll = (
113
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
114
+ - logdet_tot
115
+ )
116
+ return nll + logq # [b]
117
+ else:
118
+ flows = list(reversed(self.flows))
119
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
120
+ z = (
121
+ torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
122
+ * noise_scale
123
+ )
124
+ for flow in flows:
125
+ z = flow(z, x_mask, g=x, reverse=reverse)
126
+ z0, z1 = torch.split(z, [1, 1], 1)
127
+ logw = z0
128
+ return logw
129
+
130
+
131
+ class DurationPredictor(nn.Module):
132
+ def __init__(
133
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
134
+ ):
135
+ super().__init__()
136
+
137
+ self.in_channels = in_channels
138
+ self.filter_channels = filter_channels
139
+ self.kernel_size = kernel_size
140
+ self.p_dropout = p_dropout
141
+ self.gin_channels = gin_channels
142
+
143
+ self.drop = nn.Dropout(p_dropout)
144
+ self.conv_1 = nn.Conv1d(
145
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
146
+ )
147
+ self.norm_1 = modules.LayerNorm(filter_channels)
148
+ self.conv_2 = nn.Conv1d(
149
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
150
+ )
151
+ self.norm_2 = modules.LayerNorm(filter_channels)
152
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
153
+
154
+ if gin_channels != 0:
155
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
156
+
157
+ def forward(self, x, x_mask, g=None):
158
+ x = torch.detach(x)
159
+ if g is not None:
160
+ g = torch.detach(g)
161
+ x = x + self.cond(g)
162
+ x = self.conv_1(x * x_mask)
163
+ x = torch.relu(x)
164
+ x = self.norm_1(x)
165
+ x = self.drop(x)
166
+ x = self.conv_2(x * x_mask)
167
+ x = torch.relu(x)
168
+ x = self.norm_2(x)
169
+ x = self.drop(x)
170
+ x = self.proj(x * x_mask)
171
+ return x * x_mask
172
+
173
+
174
+ class TextEncoder(nn.Module):
175
+ def __init__(
176
+ self,
177
+ out_channels,
178
+ hidden_channels,
179
+ filter_channels,
180
+ n_heads,
181
+ n_layers,
182
+ kernel_size,
183
+ p_dropout,
184
+ latent_channels=192,
185
+ ):
186
+ super().__init__()
187
+ self.out_channels = out_channels
188
+ self.hidden_channels = hidden_channels
189
+ self.filter_channels = filter_channels
190
+ self.n_heads = n_heads
191
+ self.n_layers = n_layers
192
+ self.kernel_size = kernel_size
193
+ self.p_dropout = p_dropout
194
+ self.latent_channels = latent_channels
195
+
196
+ self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
197
+
198
+ self.encoder_ssl = attentions.Encoder(
199
+ hidden_channels,
200
+ filter_channels,
201
+ n_heads,
202
+ n_layers // 2,
203
+ kernel_size,
204
+ p_dropout,
205
+ )
206
+
207
+ self.encoder_text = attentions.Encoder(
208
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
209
+ )
210
+ self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
211
+
212
+ self.mrte = MRTE()
213
+
214
+ self.encoder2 = attentions.Encoder(
215
+ hidden_channels,
216
+ filter_channels,
217
+ n_heads,
218
+ n_layers // 2,
219
+ kernel_size,
220
+ p_dropout,
221
+ )
222
+
223
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
224
+
225
+ def forward(self, y, text, ge):
226
+ y_mask = torch.ones_like(y[:1,:1,:])
227
+
228
+ y = self.ssl_proj(y * y_mask) * y_mask
229
+ y = self.encoder_ssl(y * y_mask, y_mask)
230
+
231
+ text_mask = torch.ones_like(text).to(y.dtype).unsqueeze(0)
232
+
233
+ text = self.text_embedding(text).transpose(1, 2)
234
+ text = self.encoder_text(text * text_mask, text_mask)
235
+ y = self.mrte(y, y_mask, text, text_mask, ge)
236
+
237
+ y = self.encoder2(y * y_mask, y_mask)
238
+
239
+ stats = self.proj(y) * y_mask
240
+ m, logs = torch.split(stats, self.out_channels, dim=1)
241
+ return y, m, logs, y_mask
242
+
243
+ def extract_latent(self, x):
244
+ x = self.ssl_proj(x)
245
+ quantized, codes, commit_loss, quantized_list = self.quantizer(x)
246
+ return codes.transpose(0, 1)
247
+
248
+ def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
249
+ quantized = self.quantizer.decode(codes)
250
+
251
+ y = self.vq_proj(quantized) * y_mask
252
+ y = self.encoder_ssl(y * y_mask, y_mask)
253
+
254
+ y = self.mrte(y, y_mask, refer, refer_mask, ge)
255
+
256
+ y = self.encoder2(y * y_mask, y_mask)
257
+
258
+ stats = self.proj(y) * y_mask
259
+ m, logs = torch.split(stats, self.out_channels, dim=1)
260
+ return y, m, logs, y_mask, quantized
261
+
262
+
263
+ class ResidualCouplingBlock(nn.Module):
264
+ def __init__(
265
+ self,
266
+ channels,
267
+ hidden_channels,
268
+ kernel_size,
269
+ dilation_rate,
270
+ n_layers,
271
+ n_flows=4,
272
+ gin_channels=0,
273
+ ):
274
+ super().__init__()
275
+ self.channels = channels
276
+ self.hidden_channels = hidden_channels
277
+ self.kernel_size = kernel_size
278
+ self.dilation_rate = dilation_rate
279
+ self.n_layers = n_layers
280
+ self.n_flows = n_flows
281
+ self.gin_channels = gin_channels
282
+
283
+ self.flows = nn.ModuleList()
284
+ for i in range(n_flows):
285
+ self.flows.append(
286
+ modules.ResidualCouplingLayer(
287
+ channels,
288
+ hidden_channels,
289
+ kernel_size,
290
+ dilation_rate,
291
+ n_layers,
292
+ gin_channels=gin_channels,
293
+ mean_only=True,
294
+ )
295
+ )
296
+ self.flows.append(modules.Flip())
297
+
298
+ def forward(self, x, x_mask, g=None, reverse=False):
299
+ if not reverse:
300
+ for flow in self.flows:
301
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
302
+ else:
303
+ for flow in reversed(self.flows):
304
+ x = flow(x, x_mask, g=g, reverse=reverse)
305
+ return x
306
+
307
+
308
+ class PosteriorEncoder(nn.Module):
309
+ def __init__(
310
+ self,
311
+ in_channels,
312
+ out_channels,
313
+ hidden_channels,
314
+ kernel_size,
315
+ dilation_rate,
316
+ n_layers,
317
+ gin_channels=0,
318
+ ):
319
+ super().__init__()
320
+ self.in_channels = in_channels
321
+ self.out_channels = out_channels
322
+ self.hidden_channels = hidden_channels
323
+ self.kernel_size = kernel_size
324
+ self.dilation_rate = dilation_rate
325
+ self.n_layers = n_layers
326
+ self.gin_channels = gin_channels
327
+
328
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
329
+ self.enc = modules.WN(
330
+ hidden_channels,
331
+ kernel_size,
332
+ dilation_rate,
333
+ n_layers,
334
+ gin_channels=gin_channels,
335
+ )
336
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
337
+
338
+ def forward(self, x, x_lengths, g=None):
339
+ if g != None:
340
+ g = g.detach()
341
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
342
+ x.dtype
343
+ )
344
+ x = self.pre(x) * x_mask
345
+ x = self.enc(x, x_mask, g=g)
346
+ stats = self.proj(x) * x_mask
347
+ m, logs = torch.split(stats, self.out_channels, dim=1)
348
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
349
+ return z, m, logs, x_mask
350
+
351
+
352
+ class WNEncoder(nn.Module):
353
+ def __init__(
354
+ self,
355
+ in_channels,
356
+ out_channels,
357
+ hidden_channels,
358
+ kernel_size,
359
+ dilation_rate,
360
+ n_layers,
361
+ gin_channels=0,
362
+ ):
363
+ super().__init__()
364
+ self.in_channels = in_channels
365
+ self.out_channels = out_channels
366
+ self.hidden_channels = hidden_channels
367
+ self.kernel_size = kernel_size
368
+ self.dilation_rate = dilation_rate
369
+ self.n_layers = n_layers
370
+ self.gin_channels = gin_channels
371
+
372
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
373
+ self.enc = modules.WN(
374
+ hidden_channels,
375
+ kernel_size,
376
+ dilation_rate,
377
+ n_layers,
378
+ gin_channels=gin_channels,
379
+ )
380
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
381
+ self.norm = modules.LayerNorm(out_channels)
382
+
383
+ def forward(self, x, x_lengths, g=None):
384
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
385
+ x.dtype
386
+ )
387
+ x = self.pre(x) * x_mask
388
+ x = self.enc(x, x_mask, g=g)
389
+ out = self.proj(x) * x_mask
390
+ out = self.norm(out)
391
+ return out
392
+
393
+
394
+ class Generator(torch.nn.Module):
395
+ def __init__(
396
+ self,
397
+ initial_channel,
398
+ resblock,
399
+ resblock_kernel_sizes,
400
+ resblock_dilation_sizes,
401
+ upsample_rates,
402
+ upsample_initial_channel,
403
+ upsample_kernel_sizes,
404
+ gin_channels=0,
405
+ ):
406
+ super(Generator, self).__init__()
407
+ self.num_kernels = len(resblock_kernel_sizes)
408
+ self.num_upsamples = len(upsample_rates)
409
+ self.conv_pre = Conv1d(
410
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
411
+ )
412
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
413
+
414
+ self.ups = nn.ModuleList()
415
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
416
+ self.ups.append(
417
+ weight_norm(
418
+ ConvTranspose1d(
419
+ upsample_initial_channel // (2**i),
420
+ upsample_initial_channel // (2 ** (i + 1)),
421
+ k,
422
+ u,
423
+ padding=(k - u) // 2,
424
+ )
425
+ )
426
+ )
427
+
428
+ self.resblocks = nn.ModuleList()
429
+ for i in range(len(self.ups)):
430
+ ch = upsample_initial_channel // (2 ** (i + 1))
431
+ for j, (k, d) in enumerate(
432
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
433
+ ):
434
+ self.resblocks.append(resblock(ch, k, d))
435
+
436
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
437
+ self.ups.apply(init_weights)
438
+
439
+ if gin_channels != 0:
440
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
441
+
442
+ def forward(self, x, g=None):
443
+ x = self.conv_pre(x)
444
+ if g is not None:
445
+ x = x + self.cond(g)
446
+
447
+ for i in range(self.num_upsamples):
448
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
449
+ x = self.ups[i](x)
450
+ xs = None
451
+ for j in range(self.num_kernels):
452
+ if xs is None:
453
+ xs = self.resblocks[i * self.num_kernels + j](x)
454
+ else:
455
+ xs += self.resblocks[i * self.num_kernels + j](x)
456
+ x = xs / self.num_kernels
457
+ x = F.leaky_relu(x)
458
+ x = self.conv_post(x)
459
+ x = torch.tanh(x)
460
+
461
+ return x
462
+
463
+ def remove_weight_norm(self):
464
+ print("Removing weight norm...")
465
+ for l in self.ups:
466
+ remove_weight_norm(l)
467
+ for l in self.resblocks:
468
+ l.remove_weight_norm()
469
+
470
+
471
+ class DiscriminatorP(torch.nn.Module):
472
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
473
+ super(DiscriminatorP, self).__init__()
474
+ self.period = period
475
+ self.use_spectral_norm = use_spectral_norm
476
+ norm_f = spectral_norm if use_spectral_norm else weight_norm
477
+ self.convs = nn.ModuleList(
478
+ [
479
+ norm_f(
480
+ Conv2d(
481
+ 1,
482
+ 32,
483
+ (kernel_size, 1),
484
+ (stride, 1),
485
+ padding=(get_padding(kernel_size, 1), 0),
486
+ )
487
+ ),
488
+ norm_f(
489
+ Conv2d(
490
+ 32,
491
+ 128,
492
+ (kernel_size, 1),
493
+ (stride, 1),
494
+ padding=(get_padding(kernel_size, 1), 0),
495
+ )
496
+ ),
497
+ norm_f(
498
+ Conv2d(
499
+ 128,
500
+ 512,
501
+ (kernel_size, 1),
502
+ (stride, 1),
503
+ padding=(get_padding(kernel_size, 1), 0),
504
+ )
505
+ ),
506
+ norm_f(
507
+ Conv2d(
508
+ 512,
509
+ 1024,
510
+ (kernel_size, 1),
511
+ (stride, 1),
512
+ padding=(get_padding(kernel_size, 1), 0),
513
+ )
514
+ ),
515
+ norm_f(
516
+ Conv2d(
517
+ 1024,
518
+ 1024,
519
+ (kernel_size, 1),
520
+ 1,
521
+ padding=(get_padding(kernel_size, 1), 0),
522
+ )
523
+ ),
524
+ ]
525
+ )
526
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
527
+
528
+ def forward(self, x):
529
+ fmap = []
530
+
531
+ # 1d to 2d
532
+ b, c, t = x.shape
533
+ if t % self.period != 0: # pad first
534
+ n_pad = self.period - (t % self.period)
535
+ x = F.pad(x, (0, n_pad), "reflect")
536
+ t = t + n_pad
537
+ x = x.view(b, c, t // self.period, self.period)
538
+
539
+ for l in self.convs:
540
+ x = l(x)
541
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
542
+ fmap.append(x)
543
+ x = self.conv_post(x)
544
+ fmap.append(x)
545
+ x = torch.flatten(x, 1, -1)
546
+
547
+ return x, fmap
548
+
549
+
550
+ class DiscriminatorS(torch.nn.Module):
551
+ def __init__(self, use_spectral_norm=False):
552
+ super(DiscriminatorS, self).__init__()
553
+ norm_f = spectral_norm if use_spectral_norm else weight_norm
554
+ self.convs = nn.ModuleList(
555
+ [
556
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
557
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
558
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
559
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
560
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
561
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
562
+ ]
563
+ )
564
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
565
+
566
+ def forward(self, x):
567
+ fmap = []
568
+
569
+ for l in self.convs:
570
+ x = l(x)
571
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
572
+ fmap.append(x)
573
+ x = self.conv_post(x)
574
+ fmap.append(x)
575
+ x = torch.flatten(x, 1, -1)
576
+
577
+ return x, fmap
578
+
579
+
580
+ class MultiPeriodDiscriminator(torch.nn.Module):
581
+ def __init__(self, use_spectral_norm=False):
582
+ super(MultiPeriodDiscriminator, self).__init__()
583
+ periods = [2, 3, 5, 7, 11]
584
+
585
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
586
+ discs = discs + [
587
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
588
+ ]
589
+ self.discriminators = nn.ModuleList(discs)
590
+
591
+ def forward(self, y, y_hat):
592
+ y_d_rs = []
593
+ y_d_gs = []
594
+ fmap_rs = []
595
+ fmap_gs = []
596
+ for i, d in enumerate(self.discriminators):
597
+ y_d_r, fmap_r = d(y)
598
+ y_d_g, fmap_g = d(y_hat)
599
+ y_d_rs.append(y_d_r)
600
+ y_d_gs.append(y_d_g)
601
+ fmap_rs.append(fmap_r)
602
+ fmap_gs.append(fmap_g)
603
+
604
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
605
+
606
+
607
+ class ReferenceEncoder(nn.Module):
608
+ """
609
+ inputs --- [N, Ty/r, n_mels*r] mels
610
+ outputs --- [N, ref_enc_gru_size]
611
+ """
612
+
613
+ def __init__(self, spec_channels, gin_channels=0):
614
+ super().__init__()
615
+ self.spec_channels = spec_channels
616
+ ref_enc_filters = [32, 32, 64, 64, 128, 128]
617
+ K = len(ref_enc_filters)
618
+ filters = [1] + ref_enc_filters
619
+ convs = [
620
+ weight_norm(
621
+ nn.Conv2d(
622
+ in_channels=filters[i],
623
+ out_channels=filters[i + 1],
624
+ kernel_size=(3, 3),
625
+ stride=(2, 2),
626
+ padding=(1, 1),
627
+ )
628
+ )
629
+ for i in range(K)
630
+ ]
631
+ self.convs = nn.ModuleList(convs)
632
+ # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])
633
+
634
+ out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
635
+ self.gru = nn.GRU(
636
+ input_size=ref_enc_filters[-1] * out_channels,
637
+ hidden_size=256 // 2,
638
+ batch_first=True,
639
+ )
640
+ self.proj = nn.Linear(128, gin_channels)
641
+
642
+ def forward(self, inputs):
643
+ N = inputs.size(0)
644
+ out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
645
+ for conv in self.convs:
646
+ out = conv(out)
647
+ # out = wn(out)
648
+ out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
649
+
650
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
651
+ T = out.size(1)
652
+ N = out.size(0)
653
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
654
+
655
+ self.gru.flatten_parameters()
656
+ memory, out = self.gru(out) # out --- [1, N, 128]
657
+
658
+ return self.proj(out.squeeze(0)).unsqueeze(-1)
659
+
660
+ def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
661
+ for i in range(n_convs):
662
+ L = (L - kernel_size + 2 * pad) // stride + 1
663
+ return L
664
+
665
+
666
+ class Quantizer_module(torch.nn.Module):
667
+ def __init__(self, n_e, e_dim):
668
+ super(Quantizer_module, self).__init__()
669
+ self.embedding = nn.Embedding(n_e, e_dim)
670
+ self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)
671
+
672
+ def forward(self, x):
673
+ d = (
674
+ torch.sum(x**2, 1, keepdim=True)
675
+ + torch.sum(self.embedding.weight**2, 1)
676
+ - 2 * torch.matmul(x, self.embedding.weight.T)
677
+ )
678
+ min_indicies = torch.argmin(d, 1)
679
+ z_q = self.embedding(min_indicies)
680
+ return z_q, min_indicies
681
+
682
+
683
+ class Quantizer(torch.nn.Module):
684
+ def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160):
685
+ super(Quantizer, self).__init__()
686
+ assert embed_dim % n_code_groups == 0
687
+ self.quantizer_modules = nn.ModuleList(
688
+ [
689
+ Quantizer_module(n_codes, embed_dim // n_code_groups)
690
+ for _ in range(n_code_groups)
691
+ ]
692
+ )
693
+ self.n_code_groups = n_code_groups
694
+ self.embed_dim = embed_dim
695
+
696
+ def forward(self, xin):
697
+ # B, C, T
698
+ B, C, T = xin.shape
699
+ xin = xin.transpose(1, 2)
700
+ x = xin.reshape(-1, self.embed_dim)
701
+ x = torch.split(x, self.embed_dim // self.n_code_groups, dim=-1)
702
+ min_indicies = []
703
+ z_q = []
704
+ for _x, m in zip(x, self.quantizer_modules):
705
+ _z_q, _min_indicies = m(_x)
706
+ z_q.append(_z_q)
707
+ min_indicies.append(_min_indicies) # B * T,
708
+ z_q = torch.cat(z_q, -1).reshape(xin.shape)
709
+ loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
710
+ (z_q - xin.detach()) ** 2
711
+ )
712
+ z_q = xin + (z_q - xin).detach()
713
+ z_q = z_q.transpose(1, 2)
714
+ codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
715
+ return z_q, loss, codes.transpose(1, 2)
716
+
717
+ def embed(self, x):
718
+ # idx: N, 4, T
719
+ x = x.transpose(1, 2)
720
+ x = torch.split(x, 1, 2)
721
+ ret = []
722
+ for q, embed in zip(x, self.quantizer_modules):
723
+ q = embed.embedding(q.squeeze(-1))
724
+ ret.append(q)
725
+ ret = torch.cat(ret, -1)
726
+ return ret.transpose(1, 2) # N, C, T
727
+
728
+
729
+ class CodePredictor(nn.Module):
730
+ def __init__(
731
+ self,
732
+ hidden_channels,
733
+ filter_channels,
734
+ n_heads,
735
+ n_layers,
736
+ kernel_size,
737
+ p_dropout,
738
+ n_q=8,
739
+ dims=1024,
740
+ ssl_dim=768,
741
+ ):
742
+ super().__init__()
743
+ self.hidden_channels = hidden_channels
744
+ self.filter_channels = filter_channels
745
+ self.n_heads = n_heads
746
+ self.n_layers = n_layers
747
+ self.kernel_size = kernel_size
748
+ self.p_dropout = p_dropout
749
+
750
+ self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
751
+ self.ref_enc = modules.MelStyleEncoder(
752
+ ssl_dim, style_vector_dim=hidden_channels
753
+ )
754
+
755
+ self.encoder = attentions.Encoder(
756
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
757
+ )
758
+
759
+ self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
760
+ self.n_q = n_q
761
+ self.dims = dims
762
+
763
+ def forward(self, x, x_mask, refer, codes, infer=False):
764
+ x = x.detach()
765
+ x = self.vq_proj(x * x_mask) * x_mask
766
+ g = self.ref_enc(refer, x_mask)
767
+ x = x + g
768
+ x = self.encoder(x * x_mask, x_mask)
769
+ x = self.out_proj(x * x_mask) * x_mask
770
+ logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
771
+ 2, 3
772
+ )
773
+ target = codes[1:].transpose(0, 1)
774
+ if not infer:
775
+ logits = logits.reshape(-1, self.dims)
776
+ target = target.reshape(-1)
777
+ loss = torch.nn.functional.cross_entropy(logits, target)
778
+ return loss
779
+ else:
780
+ _, top10_preds = torch.topk(logits, 10, dim=-1)
781
+ correct_top10 = torch.any(top10_preds == target.unsqueeze(-1), dim=-1)
782
+ top10_acc = 100 * torch.mean(correct_top10.float()).detach().cpu().item()
783
+
784
+ print("Top-10 Accuracy:", top3_acc, "%")
785
+
786
+ pred_codes = torch.argmax(logits, dim=-1)
787
+ acc = 100 * torch.mean((pred_codes == target).float()).detach().cpu().item()
788
+ print("Top-1 Accuracy:", acc, "%")
789
+
790
+ return pred_codes.transpose(0, 1)
791
+
792
+
793
+ class SynthesizerTrn(nn.Module):
794
+ """
795
+ Synthesizer for Training
796
+ """
797
+
798
+ def __init__(
799
+ self,
800
+ spec_channels,
801
+ segment_size,
802
+ inter_channels,
803
+ hidden_channels,
804
+ filter_channels,
805
+ n_heads,
806
+ n_layers,
807
+ kernel_size,
808
+ p_dropout,
809
+ resblock,
810
+ resblock_kernel_sizes,
811
+ resblock_dilation_sizes,
812
+ upsample_rates,
813
+ upsample_initial_channel,
814
+ upsample_kernel_sizes,
815
+ n_speakers=0,
816
+ gin_channels=0,
817
+ use_sdp=True,
818
+ semantic_frame_rate=None,
819
+ freeze_quantizer=None,
820
+ **kwargs
821
+ ):
822
+ super().__init__()
823
+ self.spec_channels = spec_channels
824
+ self.inter_channels = inter_channels
825
+ self.hidden_channels = hidden_channels
826
+ self.filter_channels = filter_channels
827
+ self.n_heads = n_heads
828
+ self.n_layers = n_layers
829
+ self.kernel_size = kernel_size
830
+ self.p_dropout = p_dropout
831
+ self.resblock = resblock
832
+ self.resblock_kernel_sizes = resblock_kernel_sizes
833
+ self.resblock_dilation_sizes = resblock_dilation_sizes
834
+ self.upsample_rates = upsample_rates
835
+ self.upsample_initial_channel = upsample_initial_channel
836
+ self.upsample_kernel_sizes = upsample_kernel_sizes
837
+ self.segment_size = segment_size
838
+ self.n_speakers = n_speakers
839
+ self.gin_channels = gin_channels
840
+
841
+ self.use_sdp = use_sdp
842
+ self.enc_p = TextEncoder(
843
+ inter_channels,
844
+ hidden_channels,
845
+ filter_channels,
846
+ n_heads,
847
+ n_layers,
848
+ kernel_size,
849
+ p_dropout,
850
+ )
851
+ self.dec = Generator(
852
+ inter_channels,
853
+ resblock,
854
+ resblock_kernel_sizes,
855
+ resblock_dilation_sizes,
856
+ upsample_rates,
857
+ upsample_initial_channel,
858
+ upsample_kernel_sizes,
859
+ gin_channels=gin_channels,
860
+ )
861
+ self.enc_q = PosteriorEncoder(
862
+ spec_channels,
863
+ inter_channels,
864
+ hidden_channels,
865
+ 5,
866
+ 1,
867
+ 16,
868
+ gin_channels=gin_channels,
869
+ )
870
+ self.flow = ResidualCouplingBlock(
871
+ inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
872
+ )
873
+
874
+ self.ref_enc = modules.MelStyleEncoder(
875
+ spec_channels, style_vector_dim=gin_channels
876
+ )
877
+
878
+ ssl_dim = 768
879
+ self.ssl_dim = ssl_dim
880
+ assert semantic_frame_rate in ["25hz", "50hz"]
881
+ self.semantic_frame_rate = semantic_frame_rate
882
+ if semantic_frame_rate == "25hz":
883
+ self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
884
+ else:
885
+ self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
886
+
887
+ self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
888
+ if freeze_quantizer:
889
+ self.ssl_proj.requires_grad_(False)
890
+ self.quantizer.requires_grad_(False)
891
+ # self.enc_p.text_embedding.requires_grad_(False)
892
+ # self.enc_p.encoder_text.requires_grad_(False)
893
+ # self.enc_p.mrte.requires_grad_(False)
894
+
895
+ def forward(self, codes, text, refer):
896
+ refer_mask = torch.ones_like(refer[:1, :1, :])
897
+ ge = self.ref_enc(refer * refer_mask, refer_mask)
898
+
899
+ quantized = self.quantizer.decode(codes)
900
+ if self.semantic_frame_rate == "25hz":
901
+ dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
902
+ quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)
903
+
904
+ x, m_p, logs_p, y_mask = self.enc_p(
905
+ quantized, text, ge
906
+ )
907
+
908
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p)
909
+
910
+ z = self.flow(z_p, y_mask, g=ge, reverse=True)
911
+
912
+ o = self.dec((z * y_mask)[:, :, :], g=ge)
913
+ return o
914
+
915
+ def extract_latent(self, x):
916
+ ssl = self.ssl_proj(x)
917
+ quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
918
+ return codes.transpose(0, 1)
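+
+
+ # Minimal usage sketch for the modules defined above. The tensor shapes are
+ # assumptions chosen only for this demo; real inputs come from the GPT-SoVITS
+ # data pipeline and pretrained checkpoints, so the outputs here carry no meaning.
+ if __name__ == "__main__":
+     quantizer = Quantizer(embed_dim=512, n_code_groups=4, n_codes=160)
+     feats = torch.randn(2, 512, 50)  # (B, C, T) content features
+     z_q, vq_loss, codes = quantizer(feats)
+     print(z_q.shape, vq_loss.item(), codes.shape)  # (2, 512, 50), scalar, (2, 4, 50)
+
+     ref_enc = ReferenceEncoder(spec_channels=1025, gin_channels=512)
+     spec = torch.randn(2, 200, 1025)  # (B, T, n_freqs) reference spectrogram
+     print(ref_enc(spec).shape)  # (2, 512, 1) global style embedding
+
+     mpd = MultiPeriodDiscriminator()
+     wav_real, wav_fake = torch.randn(1, 1, 8192), torch.randn(1, 1, 8192)
+     y_d_rs, y_d_gs, fmap_rs, fmap_gs = mpd(wav_real, wav_fake)
+     print(len(y_d_rs), len(fmap_rs))  # 6 sub-discriminators (1 scale + 5 periods)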
module/modules.py ADDED
@@ -0,0 +1,923 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from torch.nn import Conv1d
8
+ from torch.nn.utils import weight_norm, remove_weight_norm
9
+
10
+ from module import commons
11
+ from module.commons import init_weights, get_padding
12
+ from module.transforms import piecewise_rational_quadratic_transform
13
+ import torch.distributions as D
14
+
15
+
16
+ LRELU_SLOPE = 0.1
17
+
18
+
19
+ class LayerNorm(nn.Module):
20
+ def __init__(self, channels, eps=1e-5):
21
+ super().__init__()
22
+ self.channels = channels
23
+ self.eps = eps
24
+
25
+ self.gamma = nn.Parameter(torch.ones(channels))
26
+ self.beta = nn.Parameter(torch.zeros(channels))
27
+
28
+ def forward(self, x):
29
+ x = x.transpose(1, -1)
30
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
31
+ return x.transpose(1, -1)
32
+
33
+
34
+ class ConvReluNorm(nn.Module):
35
+ def __init__(
36
+ self,
37
+ in_channels,
38
+ hidden_channels,
39
+ out_channels,
40
+ kernel_size,
41
+ n_layers,
42
+ p_dropout,
43
+ ):
44
+ super().__init__()
45
+ self.in_channels = in_channels
46
+ self.hidden_channels = hidden_channels
47
+ self.out_channels = out_channels
48
+ self.kernel_size = kernel_size
49
+ self.n_layers = n_layers
50
+ self.p_dropout = p_dropout
51
+ assert n_layers > 1, "Number of layers should be larger than 1."
52
+
53
+ self.conv_layers = nn.ModuleList()
54
+ self.norm_layers = nn.ModuleList()
55
+ self.conv_layers.append(
56
+ nn.Conv1d(
57
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
58
+ )
59
+ )
60
+ self.norm_layers.append(LayerNorm(hidden_channels))
61
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
62
+ for _ in range(n_layers - 1):
63
+ self.conv_layers.append(
64
+ nn.Conv1d(
65
+ hidden_channels,
66
+ hidden_channels,
67
+ kernel_size,
68
+ padding=kernel_size // 2,
69
+ )
70
+ )
71
+ self.norm_layers.append(LayerNorm(hidden_channels))
72
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
73
+ self.proj.weight.data.zero_()
74
+ self.proj.bias.data.zero_()
75
+
76
+ def forward(self, x, x_mask):
77
+ x_org = x
78
+ for i in range(self.n_layers):
79
+ x = self.conv_layers[i](x * x_mask)
80
+ x = self.norm_layers[i](x)
81
+ x = self.relu_drop(x)
82
+ x = x_org + self.proj(x)
83
+ return x * x_mask
84
+
85
+
86
+ class DDSConv(nn.Module):
87
+ """
88
+ Dilated and Depth-Separable Convolution
89
+ """
90
+
91
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
92
+ super().__init__()
93
+ self.channels = channels
94
+ self.kernel_size = kernel_size
95
+ self.n_layers = n_layers
96
+ self.p_dropout = p_dropout
97
+
98
+ self.drop = nn.Dropout(p_dropout)
99
+ self.convs_sep = nn.ModuleList()
100
+ self.convs_1x1 = nn.ModuleList()
101
+ self.norms_1 = nn.ModuleList()
102
+ self.norms_2 = nn.ModuleList()
103
+ for i in range(n_layers):
104
+ dilation = kernel_size**i
105
+ padding = (kernel_size * dilation - dilation) // 2
106
+ self.convs_sep.append(
107
+ nn.Conv1d(
108
+ channels,
109
+ channels,
110
+ kernel_size,
111
+ groups=channels,
112
+ dilation=dilation,
113
+ padding=padding,
114
+ )
115
+ )
116
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
117
+ self.norms_1.append(LayerNorm(channels))
118
+ self.norms_2.append(LayerNorm(channels))
119
+
120
+ def forward(self, x, x_mask, g=None):
121
+ if g is not None:
122
+ x = x + g
123
+ for i in range(self.n_layers):
124
+ y = self.convs_sep[i](x * x_mask)
125
+ y = self.norms_1[i](y)
126
+ y = F.gelu(y)
127
+ y = self.convs_1x1[i](y)
128
+ y = self.norms_2[i](y)
129
+ y = F.gelu(y)
130
+ y = self.drop(y)
131
+ x = x + y
132
+ return x * x_mask
133
+
134
+
135
+ class WN(torch.nn.Module):
136
+ def __init__(
137
+ self,
138
+ hidden_channels,
139
+ kernel_size,
140
+ dilation_rate,
141
+ n_layers,
142
+ gin_channels=0,
143
+ p_dropout=0,
144
+ ):
145
+ super(WN, self).__init__()
146
+ assert kernel_size % 2 == 1
147
+ self.hidden_channels = hidden_channels
148
+ self.kernel_size = kernel_size
149
+ self.dilation_rate = dilation_rate
150
+ self.n_layers = n_layers
151
+ self.gin_channels = gin_channels
152
+ self.p_dropout = p_dropout
153
+
154
+ self.in_layers = torch.nn.ModuleList()
155
+ self.res_skip_layers = torch.nn.ModuleList()
156
+ self.drop = nn.Dropout(p_dropout)
157
+
158
+ if gin_channels != 0:
159
+ cond_layer = torch.nn.Conv1d(
160
+ gin_channels, 2 * hidden_channels * n_layers, 1
161
+ )
162
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
163
+
164
+ for i in range(n_layers):
165
+ dilation = dilation_rate**i
166
+ padding = int((kernel_size * dilation - dilation) / 2)
167
+ in_layer = torch.nn.Conv1d(
168
+ hidden_channels,
169
+ 2 * hidden_channels,
170
+ kernel_size,
171
+ dilation=dilation,
172
+ padding=padding,
173
+ )
174
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
175
+ self.in_layers.append(in_layer)
176
+
177
+ # last one is not necessary
178
+ if i < n_layers - 1:
179
+ res_skip_channels = 2 * hidden_channels
180
+ else:
181
+ res_skip_channels = hidden_channels
182
+
183
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
184
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
185
+ self.res_skip_layers.append(res_skip_layer)
186
+
187
+ def forward(self, x, x_mask, g=None, **kwargs):
188
+ output = torch.zeros_like(x)
189
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
190
+
191
+ if g is not None:
192
+ g = self.cond_layer(g)
193
+
194
+ for i in range(self.n_layers):
195
+ x_in = self.in_layers[i](x)
196
+ if g is not None:
197
+ cond_offset = i * 2 * self.hidden_channels
198
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
199
+ else:
200
+ g_l = torch.zeros_like(x_in)
201
+
202
+ acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
203
+ acts = self.drop(acts)
204
+
205
+ res_skip_acts = self.res_skip_layers[i](acts)
206
+ if i < self.n_layers - 1:
207
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
208
+ x = (x + res_acts) * x_mask
209
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
210
+ else:
211
+ output = output + res_skip_acts
212
+ return output * x_mask
213
+
214
+ def remove_weight_norm(self):
215
+ if self.gin_channels != 0:
216
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
217
+ for l in self.in_layers:
218
+ torch.nn.utils.remove_weight_norm(l)
219
+ for l in self.res_skip_layers:
220
+ torch.nn.utils.remove_weight_norm(l)
221
+
222
+
223
+ class ResBlock1(torch.nn.Module):
224
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
225
+ super(ResBlock1, self).__init__()
226
+ self.convs1 = nn.ModuleList(
227
+ [
228
+ weight_norm(
229
+ Conv1d(
230
+ channels,
231
+ channels,
232
+ kernel_size,
233
+ 1,
234
+ dilation=dilation[0],
235
+ padding=get_padding(kernel_size, dilation[0]),
236
+ )
237
+ ),
238
+ weight_norm(
239
+ Conv1d(
240
+ channels,
241
+ channels,
242
+ kernel_size,
243
+ 1,
244
+ dilation=dilation[1],
245
+ padding=get_padding(kernel_size, dilation[1]),
246
+ )
247
+ ),
248
+ weight_norm(
249
+ Conv1d(
250
+ channels,
251
+ channels,
252
+ kernel_size,
253
+ 1,
254
+ dilation=dilation[2],
255
+ padding=get_padding(kernel_size, dilation[2]),
256
+ )
257
+ ),
258
+ ]
259
+ )
260
+ self.convs1.apply(init_weights)
261
+
262
+ self.convs2 = nn.ModuleList(
263
+ [
264
+ weight_norm(
265
+ Conv1d(
266
+ channels,
267
+ channels,
268
+ kernel_size,
269
+ 1,
270
+ dilation=1,
271
+ padding=get_padding(kernel_size, 1),
272
+ )
273
+ ),
274
+ weight_norm(
275
+ Conv1d(
276
+ channels,
277
+ channels,
278
+ kernel_size,
279
+ 1,
280
+ dilation=1,
281
+ padding=get_padding(kernel_size, 1),
282
+ )
283
+ ),
284
+ weight_norm(
285
+ Conv1d(
286
+ channels,
287
+ channels,
288
+ kernel_size,
289
+ 1,
290
+ dilation=1,
291
+ padding=get_padding(kernel_size, 1),
292
+ )
293
+ ),
294
+ ]
295
+ )
296
+ self.convs2.apply(init_weights)
297
+
298
+ def forward(self, x, x_mask=None):
299
+ for c1, c2 in zip(self.convs1, self.convs2):
300
+ xt = F.leaky_relu(x, LRELU_SLOPE)
301
+ if x_mask is not None:
302
+ xt = xt * x_mask
303
+ xt = c1(xt)
304
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
305
+ if x_mask is not None:
306
+ xt = xt * x_mask
307
+ xt = c2(xt)
308
+ x = xt + x
309
+ if x_mask is not None:
310
+ x = x * x_mask
311
+ return x
312
+
313
+ def remove_weight_norm(self):
314
+ for l in self.convs1:
315
+ remove_weight_norm(l)
316
+ for l in self.convs2:
317
+ remove_weight_norm(l)
318
+
319
+
320
+ class ResBlock2(torch.nn.Module):
321
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
322
+ super(ResBlock2, self).__init__()
323
+ self.convs = nn.ModuleList(
324
+ [
325
+ weight_norm(
326
+ Conv1d(
327
+ channels,
328
+ channels,
329
+ kernel_size,
330
+ 1,
331
+ dilation=dilation[0],
332
+ padding=get_padding(kernel_size, dilation[0]),
333
+ )
334
+ ),
335
+ weight_norm(
336
+ Conv1d(
337
+ channels,
338
+ channels,
339
+ kernel_size,
340
+ 1,
341
+ dilation=dilation[1],
342
+ padding=get_padding(kernel_size, dilation[1]),
343
+ )
344
+ ),
345
+ ]
346
+ )
347
+ self.convs.apply(init_weights)
348
+
349
+ def forward(self, x, x_mask=None):
350
+ for c in self.convs:
351
+ xt = F.leaky_relu(x, LRELU_SLOPE)
352
+ if x_mask is not None:
353
+ xt = xt * x_mask
354
+ xt = c(xt)
355
+ x = xt + x
356
+ if x_mask is not None:
357
+ x = x * x_mask
358
+ return x
359
+
360
+ def remove_weight_norm(self):
361
+ for l in self.convs:
362
+ remove_weight_norm(l)
363
+
364
+
365
+ class Log(nn.Module):
366
+ def forward(self, x, x_mask, reverse=False, **kwargs):
367
+ if not reverse:
368
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
369
+ logdet = torch.sum(-y, [1, 2])
370
+ return y, logdet
371
+ else:
372
+ x = torch.exp(x) * x_mask
373
+ return x
374
+
375
+
376
+ class Flip(nn.Module):
377
+ def forward(self, x, *args, reverse=False, **kwargs):
378
+ x = torch.flip(x, [1])
379
+ if not reverse:
380
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
381
+ return x, logdet
382
+ else:
383
+ return x
384
+
385
+
386
+ class ElementwiseAffine(nn.Module):
387
+ def __init__(self, channels):
388
+ super().__init__()
389
+ self.channels = channels
390
+ self.m = nn.Parameter(torch.zeros(channels, 1))
391
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
392
+
393
+ def forward(self, x, x_mask, reverse=False, **kwargs):
394
+ if not reverse:
395
+ y = self.m + torch.exp(self.logs) * x
396
+ y = y * x_mask
397
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
398
+ return y, logdet
399
+ else:
400
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
401
+ return x
402
+
403
+
404
+ class ResidualCouplingLayer(nn.Module):
405
+ def __init__(
406
+ self,
407
+ channels,
408
+ hidden_channels,
409
+ kernel_size,
410
+ dilation_rate,
411
+ n_layers,
412
+ p_dropout=0,
413
+ gin_channels=0,
414
+ mean_only=False,
415
+ ):
416
+ assert channels % 2 == 0, "channels should be divisible by 2"
417
+ super().__init__()
418
+ self.channels = channels
419
+ self.hidden_channels = hidden_channels
420
+ self.kernel_size = kernel_size
421
+ self.dilation_rate = dilation_rate
422
+ self.n_layers = n_layers
423
+ self.half_channels = channels // 2
424
+ self.mean_only = mean_only
425
+
426
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
427
+ self.enc = WN(
428
+ hidden_channels,
429
+ kernel_size,
430
+ dilation_rate,
431
+ n_layers,
432
+ p_dropout=p_dropout,
433
+ gin_channels=gin_channels,
434
+ )
435
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
436
+ self.post.weight.data.zero_()
437
+ self.post.bias.data.zero_()
438
+
439
+ def forward(self, x, x_mask, g=None, reverse=False):
440
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
441
+ h = self.pre(x0) * x_mask
442
+ h = self.enc(h, x_mask, g=g)
443
+ stats = self.post(h) * x_mask
444
+ if not self.mean_only:
445
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
446
+ else:
447
+ m = stats
448
+ logs = torch.zeros_like(m)
449
+
450
+ if not reverse:
451
+ x1 = m + x1 * torch.exp(logs) * x_mask
452
+ x = torch.cat([x0, x1], 1)
453
+ logdet = torch.sum(logs, [1, 2])
454
+ return x, logdet
455
+ else:
456
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
457
+ x = torch.cat([x0, x1], 1)
458
+ return x
459
+
460
+
461
+ class ConvFlow(nn.Module):
462
+ def __init__(
463
+ self,
464
+ in_channels,
465
+ filter_channels,
466
+ kernel_size,
467
+ n_layers,
468
+ num_bins=10,
469
+ tail_bound=5.0,
470
+ ):
471
+ super().__init__()
472
+ self.in_channels = in_channels
473
+ self.filter_channels = filter_channels
474
+ self.kernel_size = kernel_size
475
+ self.n_layers = n_layers
476
+ self.num_bins = num_bins
477
+ self.tail_bound = tail_bound
478
+ self.half_channels = in_channels // 2
479
+
480
+ self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
481
+ self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
482
+ self.proj = nn.Conv1d(
483
+ filter_channels, self.half_channels * (num_bins * 3 - 1), 1
484
+ )
485
+ self.proj.weight.data.zero_()
486
+ self.proj.bias.data.zero_()
487
+
488
+ def forward(self, x, x_mask, g=None, reverse=False):
489
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
490
+ h = self.pre(x0)
491
+ h = self.convs(h, x_mask, g=g)
492
+ h = self.proj(h) * x_mask
493
+
494
+ b, c, t = x0.shape
495
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
496
+
497
+ unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
498
+ unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
499
+ self.filter_channels
500
+ )
501
+ unnormalized_derivatives = h[..., 2 * self.num_bins :]
502
+
503
+ x1, logabsdet = piecewise_rational_quadratic_transform(
504
+ x1,
505
+ unnormalized_widths,
506
+ unnormalized_heights,
507
+ unnormalized_derivatives,
508
+ inverse=reverse,
509
+ tails="linear",
510
+ tail_bound=self.tail_bound,
511
+ )
512
+
513
+ x = torch.cat([x0, x1], 1) * x_mask
514
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
515
+ if not reverse:
516
+ return x, logdet
517
+ else:
518
+ return x
519
+
520
+
521
+ class LinearNorm(nn.Module):
522
+ def __init__(
523
+ self,
524
+ in_channels,
525
+ out_channels,
526
+ bias=True,
527
+ spectral_norm=False,
528
+ ):
529
+ super(LinearNorm, self).__init__()
530
+ self.fc = nn.Linear(in_channels, out_channels, bias)
531
+
532
+ if spectral_norm:
533
+ self.fc = nn.utils.spectral_norm(self.fc)
534
+
535
+ def forward(self, input):
536
+ out = self.fc(input)
537
+ return out
538
+
539
+
540
+ class Mish(nn.Module):
541
+ def __init__(self):
542
+ super(Mish, self).__init__()
543
+
544
+ def forward(self, x):
545
+ return x * torch.tanh(F.softplus(x))
546
+
547
+
548
+ class Conv1dGLU(nn.Module):
549
+ """
550
+ Conv1d + GLU(Gated Linear Unit) with residual connection.
551
+ For GLU refer to https://arxiv.org/abs/1612.08083 paper.
552
+ """
553
+
554
+ def __init__(self, in_channels, out_channels, kernel_size, dropout):
555
+ super(Conv1dGLU, self).__init__()
556
+ self.out_channels = out_channels
557
+ self.conv1 = ConvNorm(in_channels, 2 * out_channels, kernel_size=kernel_size)
558
+ self.dropout = nn.Dropout(dropout)
559
+
560
+ def forward(self, x):
561
+ residual = x
562
+ x = self.conv1(x)
563
+ x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
564
+ x = x1 * torch.sigmoid(x2)
565
+ x = residual + self.dropout(x)
566
+ return x
567
+
568
+
569
+ class ConvNorm(nn.Module):
570
+ def __init__(
571
+ self,
572
+ in_channels,
573
+ out_channels,
574
+ kernel_size=1,
575
+ stride=1,
576
+ padding=None,
577
+ dilation=1,
578
+ bias=True,
579
+ spectral_norm=False,
580
+ ):
581
+ super(ConvNorm, self).__init__()
582
+
583
+ if padding is None:
584
+ assert kernel_size % 2 == 1
585
+ padding = int(dilation * (kernel_size - 1) / 2)
586
+
587
+ self.conv = torch.nn.Conv1d(
588
+ in_channels,
589
+ out_channels,
590
+ kernel_size=kernel_size,
591
+ stride=stride,
592
+ padding=padding,
593
+ dilation=dilation,
594
+ bias=bias,
595
+ )
596
+
597
+ if spectral_norm:
598
+ self.conv = nn.utils.spectral_norm(self.conv)
599
+
600
+ def forward(self, input):
601
+ out = self.conv(input)
602
+ return out
603
+
604
+
605
+ class MultiHeadAttention(nn.Module):
606
+ """Multi-Head Attention module"""
607
+
608
+ def __init__(self, n_head, d_model, d_k, d_v, dropout=0.0, spectral_norm=False):
609
+ super().__init__()
610
+
611
+ self.n_head = n_head
612
+ self.d_k = d_k
613
+ self.d_v = d_v
614
+
615
+ self.w_qs = nn.Linear(d_model, n_head * d_k)
616
+ self.w_ks = nn.Linear(d_model, n_head * d_k)
617
+ self.w_vs = nn.Linear(d_model, n_head * d_v)
618
+
619
+ self.attention = ScaledDotProductAttention(
620
+ temperature=np.power(d_model, 0.5), dropout=dropout
621
+ )
622
+
623
+ self.fc = nn.Linear(n_head * d_v, d_model)
624
+ self.dropout = nn.Dropout(dropout)
625
+
626
+ if spectral_norm:
627
+ self.w_qs = nn.utils.spectral_norm(self.w_qs)
628
+ self.w_ks = nn.utils.spectral_norm(self.w_ks)
629
+ self.w_vs = nn.utils.spectral_norm(self.w_vs)
630
+ self.fc = nn.utils.spectral_norm(self.fc)
631
+
632
+ def forward(self, x, mask=None):
633
+ d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
634
+ sz_b, len_x, _ = x.size()
635
+
636
+ residual = x
637
+
638
+ q = self.w_qs(x).view(sz_b, len_x, n_head, d_k)
639
+ k = self.w_ks(x).view(sz_b, len_x, n_head, d_k)
640
+ v = self.w_vs(x).view(sz_b, len_x, n_head, d_v)
641
+ q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k) # (n*b) x lq x dk
642
+ k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k) # (n*b) x lk x dk
643
+ v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_v) # (n*b) x lv x dv
644
+
645
+ if mask is not None:
646
+ slf_mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
647
+ else:
648
+ slf_mask = None
649
+ output, attn = self.attention(q, k, v, mask=slf_mask)
650
+
651
+ output = output.view(n_head, sz_b, len_x, d_v)
652
+ output = (
653
+ output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1)
654
+ ) # b x lq x (n*dv)
655
+
656
+ output = self.fc(output)
657
+
658
+ output = self.dropout(output) + residual
659
+ return output, attn
660
+
661
+
662
+ class ScaledDotProductAttention(nn.Module):
663
+ """Scaled Dot-Product Attention"""
664
+
665
+ def __init__(self, temperature, dropout):
666
+ super().__init__()
667
+ self.temperature = temperature
668
+ self.softmax = nn.Softmax(dim=2)
669
+ self.dropout = nn.Dropout(dropout)
670
+
671
+ def forward(self, q, k, v, mask=None):
672
+ attn = torch.bmm(q, k.transpose(1, 2))
673
+ attn = attn / self.temperature
674
+
675
+ if mask is not None:
676
+ attn = attn.masked_fill(mask, -np.inf)
677
+
678
+ attn = self.softmax(attn)
679
+ p_attn = self.dropout(attn)
680
+
681
+ output = torch.bmm(p_attn, v)
682
+ return output, attn
683
+
684
+
685
+ class MelStyleEncoder(nn.Module):
686
+ """MelStyleEncoder"""
687
+
688
+ def __init__(
689
+ self,
690
+ n_mel_channels=80,
691
+ style_hidden=128,
692
+ style_vector_dim=256,
693
+ style_kernel_size=5,
694
+ style_head=2,
695
+ dropout=0.1,
696
+ ):
697
+ super(MelStyleEncoder, self).__init__()
698
+ self.in_dim = n_mel_channels
699
+ self.hidden_dim = style_hidden
700
+ self.out_dim = style_vector_dim
701
+ self.kernel_size = style_kernel_size
702
+ self.n_head = style_head
703
+ self.dropout = dropout
704
+
705
+ self.spectral = nn.Sequential(
706
+ LinearNorm(self.in_dim, self.hidden_dim),
707
+ Mish(),
708
+ nn.Dropout(self.dropout),
709
+ LinearNorm(self.hidden_dim, self.hidden_dim),
710
+ Mish(),
711
+ nn.Dropout(self.dropout),
712
+ )
713
+
714
+ self.temporal = nn.Sequential(
715
+ Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
716
+ Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
717
+ )
718
+
719
+ self.slf_attn = MultiHeadAttention(
720
+ self.n_head,
721
+ self.hidden_dim,
722
+ self.hidden_dim // self.n_head,
723
+ self.hidden_dim // self.n_head,
724
+ self.dropout,
725
+ )
726
+
727
+ self.fc = LinearNorm(self.hidden_dim, self.out_dim)
728
+
729
+ def temporal_avg_pool(self, x, mask=None):
730
+ if mask is None:
731
+ out = torch.mean(x, dim=1)
732
+ else:
733
+ len_ = (~mask).sum(dim=1).unsqueeze(1)
734
+ x = x.masked_fill(mask.unsqueeze(-1), 0)
735
+ x = x.sum(dim=1)
736
+ out = torch.div(x, len_)
737
+ return out
738
+
739
+ def forward(self, x, mask=None):
740
+ x = x.transpose(1, 2)
741
+ if mask is not None:
742
+ mask = (mask.int() == 0).squeeze(1)
743
+ max_len = x.shape[1]
744
+ slf_attn_mask = (
745
+ mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
746
+ )
747
+
748
+ # spectral
749
+ x = self.spectral(x)
750
+ # temporal
751
+ x = x.transpose(1, 2)
752
+ x = self.temporal(x)
753
+ x = x.transpose(1, 2)
754
+ # self-attention
755
+ if mask is not None:
756
+ x = x.masked_fill(mask.unsqueeze(-1), 0)
757
+ x, _ = self.slf_attn(x, mask=slf_attn_mask)
758
+ # fc
759
+ x = self.fc(x)
760
+ # temporal average pooling
761
+ w = self.temporal_avg_pool(x, mask=mask)
762
+
763
+ return w.unsqueeze(-1)
764
+
765
+
766
+ class MelStyleEncoderVAE(nn.Module):
767
+ def __init__(self, spec_channels, z_latent_dim, emb_dim):
768
+ super().__init__()
769
+ self.ref_encoder = MelStyleEncoder(spec_channels, style_vector_dim=emb_dim)
770
+ self.fc1 = nn.Linear(emb_dim, z_latent_dim)
771
+ self.fc2 = nn.Linear(emb_dim, z_latent_dim)
772
+ self.fc3 = nn.Linear(z_latent_dim, emb_dim)
773
+ self.z_latent_dim = z_latent_dim
774
+
775
+ def reparameterize(self, mu, logvar):
776
+ if self.training:
777
+ std = torch.exp(0.5 * logvar)
778
+ eps = torch.randn_like(std)
779
+ return eps.mul(std).add_(mu)
780
+ else:
781
+ return mu
782
+
783
+ def forward(self, inputs, mask=None):
784
+ enc_out = self.ref_encoder(inputs.squeeze(-1), mask).squeeze(-1)
785
+ mu = self.fc1(enc_out)
786
+ logvar = self.fc2(enc_out)
787
+ posterior = D.Normal(mu, torch.exp(logvar))
788
+ kl_divergence = D.kl_divergence(
789
+ posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar))
790
+ )
791
+ loss_kl = kl_divergence.mean()
792
+
793
+ z = posterior.rsample()
794
+ style_embed = self.fc3(z)
795
+
796
+ return style_embed.unsqueeze(-1), loss_kl
797
+
798
+ def infer(self, inputs=None, random_sample=False, manual_latent=None):
799
+ if manual_latent is None:
800
+ if random_sample:
801
+ dev = next(self.parameters()).device
802
+ posterior = D.Normal(
803
+ torch.zeros(1, self.z_latent_dim, device=dev),
804
+ torch.ones(1, self.z_latent_dim, device=dev),
805
+ )
806
+ z = posterior.rsample()
807
+ else:
808
+ enc_out = self.ref_encoder(inputs.transpose(1, 2))
809
+ mu = self.fc1(enc_out)
810
+ z = mu
811
+ else:
812
+ z = manual_latent
813
+ style_embed = self.fc3(z)
814
+ return style_embed.unsqueeze(-1), z
815
+
816
+
817
+ class ActNorm(nn.Module):
818
+ def __init__(self, channels, ddi=False, **kwargs):
819
+ super().__init__()
820
+ self.channels = channels
821
+ self.initialized = not ddi
822
+
823
+ self.logs = nn.Parameter(torch.zeros(1, channels, 1))
824
+ self.bias = nn.Parameter(torch.zeros(1, channels, 1))
825
+
826
+ def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs):
827
+ if x_mask is None:
828
+ x_mask = torch.ones(x.size(0), 1, x.size(2)).to(
829
+ device=x.device, dtype=x.dtype
830
+ )
831
+ x_len = torch.sum(x_mask, [1, 2])
832
+ if not self.initialized:
833
+ self.initialize(x, x_mask)
834
+ self.initialized = True
835
+
836
+ if reverse:
837
+ z = (x - self.bias) * torch.exp(-self.logs) * x_mask
838
+ logdet = None
839
+ return z
840
+ else:
841
+ z = (self.bias + torch.exp(self.logs) * x) * x_mask
842
+ logdet = torch.sum(self.logs) * x_len # [b]
843
+ return z, logdet
844
+
845
+ def store_inverse(self):
846
+ pass
847
+
848
+ def set_ddi(self, ddi):
849
+ self.initialized = not ddi
850
+
851
+ def initialize(self, x, x_mask):
852
+ with torch.no_grad():
853
+ denom = torch.sum(x_mask, [0, 2])
854
+ m = torch.sum(x * x_mask, [0, 2]) / denom
855
+ m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
856
+ v = m_sq - (m**2)
857
+ logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
858
+
859
+ bias_init = (
860
+ (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
861
+ )
862
+ logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)
863
+
864
+ self.bias.data.copy_(bias_init)
865
+ self.logs.data.copy_(logs_init)
866
+
867
+
868
+ class InvConvNear(nn.Module):
869
+ def __init__(self, channels, n_split=4, no_jacobian=False, **kwargs):
870
+ super().__init__()
871
+ assert n_split % 2 == 0
872
+ self.channels = channels
873
+ self.n_split = n_split
874
+ self.no_jacobian = no_jacobian
875
+
876
+ w_init = torch.linalg.qr(
877
+ torch.FloatTensor(self.n_split, self.n_split).normal_()
878
+ )[0]
879
+ if torch.det(w_init) < 0:
880
+ w_init[:, 0] = -1 * w_init[:, 0]
881
+ self.weight = nn.Parameter(w_init)
882
+
883
+ def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs):
884
+ b, c, t = x.size()
885
+ assert c % self.n_split == 0
886
+ if x_mask is None:
887
+ x_mask = 1
888
+ x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
889
+ else:
890
+ x_len = torch.sum(x_mask, [1, 2])
891
+
892
+ x = x.view(b, 2, c // self.n_split, self.n_split // 2, t)
893
+ x = (
894
+ x.permute(0, 1, 3, 2, 4)
895
+ .contiguous()
896
+ .view(b, self.n_split, c // self.n_split, t)
897
+ )
898
+
899
+ if reverse:
900
+ if hasattr(self, "weight_inv"):
901
+ weight = self.weight_inv
902
+ else:
903
+ weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
904
+ logdet = None
905
+ else:
906
+ weight = self.weight
907
+ if self.no_jacobian:
908
+ logdet = 0
909
+ else:
910
+ logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len # [b]
911
+
912
+ weight = weight.view(self.n_split, self.n_split, 1, 1)
913
+ z = F.conv2d(x, weight)
914
+
915
+ z = z.view(b, 2, self.n_split // 2, c // self.n_split, t)
916
+ z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask
917
+ if reverse:
918
+ return z
919
+ else:
920
+ return z, logdet
921
+
922
+ def store_inverse(self):
923
+ self.weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
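+
+
+ # Minimal usage sketch for two of the building blocks above (tensor shapes are
+ # assumptions for this demo only): MelStyleEncoder pools a mel spectrogram into
+ # one style vector per utterance, and WN is the gated WaveNet-style stack.
+ if __name__ == "__main__":
+     style_enc = MelStyleEncoder(n_mel_channels=80, style_vector_dim=256)
+     mel = torch.randn(2, 80, 120)  # (B, n_mel_channels, T)
+     print(style_enc(mel).shape)  # (2, 256, 1)
+
+     wn = WN(hidden_channels=192, kernel_size=5, dilation_rate=1, n_layers=4)
+     h = torch.randn(2, 192, 100)
+     h_mask = torch.ones(2, 1, 100)
+     print(wn(h, h_mask).shape)  # (2, 192, 100)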
module/mrte_model.py ADDED
@@ -0,0 +1,192 @@
1
+ # This is the multi-reference timbre encoder (MRTE)
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn.utils import remove_weight_norm, weight_norm
6
+ from module.attentions import MultiHeadAttention
7
+
8
+
9
+ class MRTE(nn.Module):
10
+ def __init__(
11
+ self,
12
+ content_enc_channels=192,
13
+ hidden_size=512,
14
+ out_channels=192,
15
+ kernel_size=5,
16
+ n_heads=4,
17
+ ge_layer=2,
18
+ ):
19
+ super(MRTE, self).__init__()
20
+ self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
21
+ self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
22
+ self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
23
+ self.c_post = nn.Conv1d(hidden_size, out_channels, 1)
24
+
25
+ def forward(self, ssl_enc, ssl_mask, text, text_mask, ge, test=None):
26
+ if ge is None:
27
+ ge = 0
28
+ attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)
29
+
30
+ ssl_enc = self.c_pre(ssl_enc * ssl_mask)
31
+ text_enc = self.text_pre(text * text_mask)
32
+ if test is not None:
33
+ if test == 0:
34
+ x = (
35
+ self.cross_attention(
36
+ ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
37
+ )
38
+ + ssl_enc
39
+ + ge
40
+ )
41
+ elif test == 1:
42
+ x = ssl_enc + ge
43
+ elif test == 2:
44
+ x = (
45
+ self.cross_attention(
46
+ ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask
47
+ )
48
+ + ge
49
+ )
50
+ else:
51
+ raise ValueError("test should be 0,1,2")
52
+ else:
53
+ x = (
54
+ self.cross_attention(
55
+ ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
56
+ )
57
+ + ssl_enc
58
+ + ge
59
+ )
60
+ x = self.c_post(x * ssl_mask)
61
+ return x
62
+
63
+
64
+ class SpeakerEncoder(torch.nn.Module):
65
+ def __init__(
66
+ self,
67
+ mel_n_channels=80,
68
+ model_num_layers=2,
69
+ model_hidden_size=256,
70
+ model_embedding_size=256,
71
+ ):
72
+ super(SpeakerEncoder, self).__init__()
73
+ self.lstm = nn.LSTM(
74
+ mel_n_channels, model_hidden_size, model_num_layers, batch_first=True
75
+ )
76
+ self.linear = nn.Linear(model_hidden_size, model_embedding_size)
77
+ self.relu = nn.ReLU()
78
+
79
+ def forward(self, mels):
80
+ self.lstm.flatten_parameters()
81
+ _, (hidden, _) = self.lstm(mels.transpose(-1, -2))
82
+ embeds_raw = self.relu(self.linear(hidden[-1]))
83
+ return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
84
+
85
+
86
+ class MELEncoder(nn.Module):
87
+ def __init__(
88
+ self,
89
+ in_channels,
90
+ out_channels,
91
+ hidden_channels,
92
+ kernel_size,
93
+ dilation_rate,
94
+ n_layers,
95
+ ):
96
+ super().__init__()
97
+ self.in_channels = in_channels
98
+ self.out_channels = out_channels
99
+ self.hidden_channels = hidden_channels
100
+ self.kernel_size = kernel_size
101
+ self.dilation_rate = dilation_rate
102
+ self.n_layers = n_layers
103
+
104
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
105
+ self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers)
106
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
107
+
108
+ def forward(self, x):
109
+ # print(x.shape,x_lengths.shape)
110
+ x = self.pre(x)
111
+ x = self.enc(x)
112
+ x = self.proj(x)
113
+ return x
114
+
115
+
116
+ class WN(torch.nn.Module):
117
+ def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers):
118
+ super(WN, self).__init__()
119
+ assert kernel_size % 2 == 1
120
+ self.hidden_channels = hidden_channels
121
+ self.kernel_size = kernel_size
122
+ self.dilation_rate = dilation_rate
123
+ self.n_layers = n_layers
124
+
125
+ self.in_layers = torch.nn.ModuleList()
126
+ self.res_skip_layers = torch.nn.ModuleList()
127
+
128
+ for i in range(n_layers):
129
+ dilation = dilation_rate**i
130
+ padding = int((kernel_size * dilation - dilation) / 2)
131
+ in_layer = nn.Conv1d(
132
+ hidden_channels,
133
+ 2 * hidden_channels,
134
+ kernel_size,
135
+ dilation=dilation,
136
+ padding=padding,
137
+ )
138
+ in_layer = weight_norm(in_layer)
139
+ self.in_layers.append(in_layer)
140
+
141
+ # last one is not necessary
142
+ if i < n_layers - 1:
143
+ res_skip_channels = 2 * hidden_channels
144
+ else:
145
+ res_skip_channels = hidden_channels
146
+
147
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
148
+ res_skip_layer = weight_norm(res_skip_layer, name="weight")
149
+ self.res_skip_layers.append(res_skip_layer)
150
+
151
+ def forward(self, x):
152
+ output = torch.zeros_like(x)
153
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
154
+
155
+ for i in range(self.n_layers):
156
+ x_in = self.in_layers[i](x)
157
+
158
+ acts = fused_add_tanh_sigmoid_multiply(x_in, n_channels_tensor)
159
+
160
+ res_skip_acts = self.res_skip_layers[i](acts)
161
+ if i < self.n_layers - 1:
162
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
163
+ x = x + res_acts
164
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
165
+ else:
166
+ output = output + res_skip_acts
167
+ return output
168
+
169
+ def remove_weight_norm(self):
170
+ for l in self.in_layers:
171
+ remove_weight_norm(l)
172
+ for l in self.res_skip_layers:
173
+ remove_weight_norm(l)
174
+
175
+
176
+ @torch.jit.script
177
+ def fused_add_tanh_sigmoid_multiply(input, n_channels):
178
+ n_channels_int = n_channels[0]
179
+ t_act = torch.tanh(input[:, :n_channels_int, :])
180
+ s_act = torch.sigmoid(input[:, n_channels_int:, :])
181
+ acts = t_act * s_act
182
+ return acts
183
+
184
+
185
+ if __name__ == "__main__":
186
+ content_enc = torch.randn(3, 192, 100)
187
+ content_mask = torch.ones(3, 1, 100)
188
+ ref_mel = torch.randn(3, 128, 30)
189
+ ref_mask = torch.ones(3, 1, 30)
190
+ model = MRTE()
191
+ out = model(content_enc, content_mask, ref_mel, ref_mask)
192
+ print(out.shape)
module/quantize.py ADDED
@@ -0,0 +1,119 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Residual vector quantizer implementation."""
8
+
9
+ from dataclasses import dataclass, field
10
+ import math
11
+ import typing as tp
12
+
13
+ import torch
14
+ from torch import nn
15
+
16
+ from module.core_vq import ResidualVectorQuantization
17
+
18
+
19
+ @dataclass
20
+ class QuantizedResult:
21
+ quantized: torch.Tensor
22
+ codes: torch.Tensor
23
+ bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
24
+ penalty: tp.Optional[torch.Tensor] = None
25
+ metrics: dict = field(default_factory=dict)
26
+
27
+
28
+ class ResidualVectorQuantizer(nn.Module):
29
+ """Residual Vector Quantizer.
30
+ Args:
31
+ dimension (int): Dimension of the codebooks.
32
+ n_q (int): Number of residual vector quantizers used.
33
+ bins (int): Codebook size.
34
+ decay (float): Decay for exponential moving average over the codebooks.
35
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
36
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
37
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
38
+ that have an exponential moving average cluster size less than the specified threshold with
39
+ randomly selected vector from the current batch.
40
+ """
41
+
42
+ def __init__(
43
+ self,
44
+ dimension: int = 256,
45
+ n_q: int = 8,
46
+ bins: int = 1024,
47
+ decay: float = 0.99,
48
+ kmeans_init: bool = True,
49
+ kmeans_iters: int = 50,
50
+ threshold_ema_dead_code: int = 2,
51
+ ):
52
+ super().__init__()
53
+ self.n_q = n_q
54
+ self.dimension = dimension
55
+ self.bins = bins
56
+ self.decay = decay
57
+ self.kmeans_init = kmeans_init
58
+ self.kmeans_iters = kmeans_iters
59
+ self.threshold_ema_dead_code = threshold_ema_dead_code
60
+ self.vq = ResidualVectorQuantization(
61
+ dim=self.dimension,
62
+ codebook_size=self.bins,
63
+ num_quantizers=self.n_q,
64
+ decay=self.decay,
65
+ kmeans_init=self.kmeans_init,
66
+ kmeans_iters=self.kmeans_iters,
67
+ threshold_ema_dead_code=self.threshold_ema_dead_code,
68
+ )
69
+
70
+ def forward(
71
+ self,
72
+ x: torch.Tensor,
73
+ n_q: tp.Optional[int] = None,
74
+ layers: tp.Optional[list] = None,
75
+ ) -> QuantizedResult:
76
+ """Residual vector quantization on the given input tensor.
77
+ Args:
78
+ x (torch.Tensor): Input tensor.
79
+ n_q (int): Number of quantizers used to quantize. Default: all quantizers.
80
+ layers (list): Layers whose quantized outputs should be returned. Default: None.
81
+ Returns:
82
+ QuantizedResult:
83
+ The quantized (or approximately quantized) representation with
84
+ the associated number of quantizers used and the requested per-layer quantized outputs.
85
+ """
86
+ n_q = n_q if n_q else self.n_q
87
+ if layers and max(layers) >= n_q:
88
+ raise ValueError(
89
+ f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B."
90
+ )
91
+ quantized, codes, commit_loss, quantized_list = self.vq(
92
+ x, n_q=n_q, layers=layers
93
+ )
94
+ return quantized, codes, torch.mean(commit_loss), quantized_list
95
+
96
+ def encode(
97
+ self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
98
+ ) -> torch.Tensor:
99
+ """Encode a given input tensor with the specified sample rate at the given bandwidth.
100
+ The RVQ encode method sets the appropriate number of quantizer to use
101
+ and returns indices for each quantizer.
102
+ Args:
103
+ x (torch.Tensor): Input tensor.
104
+ n_q (int): Number of quantizers used to quantize. Default: all quantizers.
105
+ st (int): Index of the quantizer layer to start encoding from. Default: 0.
106
+ """
107
+ n_q = n_q if n_q else self.n_q
108
+ st = st or 0
109
+ codes = self.vq.encode(x, n_q=n_q, st=st)
110
+ return codes
111
+
112
+ def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
113
+ """Decode the given codes to the quantized representation.
114
+ Args:
115
+ codes (torch.Tensor): Input indices for each quantizer.
116
+ st (int): Index of the quantizer layer to start decoding from. Default: 0.
117
+ """
118
+ quantized = self.vq.decode(codes, st=st)
119
+ return quantized
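+
+
+ # Minimal usage sketch of the residual vector quantizer (the 768-dim feature
+ # size and shapes are assumptions for this demo; in practice the codebooks are
+ # loaded from a pretrained checkpoint, so codes from random features are
+ # meaningless).
+ if __name__ == "__main__":
+     rvq = ResidualVectorQuantizer(dimension=768, n_q=1, bins=1024)
+     feats = torch.randn(1, 768, 50)  # (B, dimension, T)
+     quantized, codes, commit_loss, _ = rvq(feats)
+     print(quantized.shape, codes.shape, commit_loss.item())
+
+     codes = rvq.encode(feats)   # (n_q, B, T) codebook indices
+     recon = rvq.decode(codes)   # (B, dimension, T) quantized approximation
+     print(codes.shape, recon.shape)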
module/transforms.py ADDED
@@ -0,0 +1,209 @@
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+ import numpy as np
5
+
6
+
7
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
8
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
9
+ DEFAULT_MIN_DERIVATIVE = 1e-3
10
+
11
+
12
+ def piecewise_rational_quadratic_transform(
13
+ inputs,
14
+ unnormalized_widths,
15
+ unnormalized_heights,
16
+ unnormalized_derivatives,
17
+ inverse=False,
18
+ tails=None,
19
+ tail_bound=1.0,
20
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
21
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
22
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
23
+ ):
24
+ if tails is None:
25
+ spline_fn = rational_quadratic_spline
26
+ spline_kwargs = {}
27
+ else:
28
+ spline_fn = unconstrained_rational_quadratic_spline
29
+ spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
30
+
31
+ outputs, logabsdet = spline_fn(
32
+ inputs=inputs,
33
+ unnormalized_widths=unnormalized_widths,
34
+ unnormalized_heights=unnormalized_heights,
35
+ unnormalized_derivatives=unnormalized_derivatives,
36
+ inverse=inverse,
37
+ min_bin_width=min_bin_width,
38
+ min_bin_height=min_bin_height,
39
+ min_derivative=min_derivative,
40
+ **spline_kwargs
41
+ )
42
+ return outputs, logabsdet
43
+
44
+
45
+ def searchsorted(bin_locations, inputs, eps=1e-6):
46
+ bin_locations[..., -1] += eps
47
+ return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
48
+
49
+
50
+ def unconstrained_rational_quadratic_spline(
51
+ inputs,
52
+ unnormalized_widths,
53
+ unnormalized_heights,
54
+ unnormalized_derivatives,
55
+ inverse=False,
56
+ tails="linear",
57
+ tail_bound=1.0,
58
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
59
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
60
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
61
+ ):
62
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
63
+ outside_interval_mask = ~inside_interval_mask
64
+
65
+ outputs = torch.zeros_like(inputs)
66
+ logabsdet = torch.zeros_like(inputs)
67
+
68
+ if tails == "linear":
69
+ unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
70
+ constant = np.log(np.exp(1 - min_derivative) - 1)
71
+ unnormalized_derivatives[..., 0] = constant
72
+ unnormalized_derivatives[..., -1] = constant
73
+
74
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
75
+ logabsdet[outside_interval_mask] = 0
76
+ else:
77
+ raise RuntimeError("{} tails are not implemented.".format(tails))
78
+
79
+ (
80
+ outputs[inside_interval_mask],
81
+ logabsdet[inside_interval_mask],
82
+ ) = rational_quadratic_spline(
83
+ inputs=inputs[inside_interval_mask],
84
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87
+ inverse=inverse,
88
+ left=-tail_bound,
89
+ right=tail_bound,
90
+ bottom=-tail_bound,
91
+ top=tail_bound,
92
+ min_bin_width=min_bin_width,
93
+ min_bin_height=min_bin_height,
94
+ min_derivative=min_derivative,
95
+ )
96
+
97
+ return outputs, logabsdet
98
+
99
+
100
+ def rational_quadratic_spline(
101
+ inputs,
102
+ unnormalized_widths,
103
+ unnormalized_heights,
104
+ unnormalized_derivatives,
105
+ inverse=False,
106
+ left=0.0,
107
+ right=1.0,
108
+ bottom=0.0,
109
+ top=1.0,
110
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
111
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
112
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
113
+ ):
114
+ if torch.min(inputs) < left or torch.max(inputs) > right:
115
+ raise ValueError("Input to a transform is not within its domain")
116
+
117
+ num_bins = unnormalized_widths.shape[-1]
118
+
119
+ if min_bin_width * num_bins > 1.0:
120
+ raise ValueError("Minimal bin width too large for the number of bins")
121
+ if min_bin_height * num_bins > 1.0:
122
+ raise ValueError("Minimal bin height too large for the number of bins")
123
+
124
+ widths = F.softmax(unnormalized_widths, dim=-1)
125
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
126
+ cumwidths = torch.cumsum(widths, dim=-1)
127
+ cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
128
+ cumwidths = (right - left) * cumwidths + left
129
+ cumwidths[..., 0] = left
130
+ cumwidths[..., -1] = right
131
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
132
+
133
+ derivatives = min_derivative + F.softplus(unnormalized_derivatives)
134
+
135
+ heights = F.softmax(unnormalized_heights, dim=-1)
136
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
137
+ cumheights = torch.cumsum(heights, dim=-1)
138
+ cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
139
+ cumheights = (top - bottom) * cumheights + bottom
140
+ cumheights[..., 0] = bottom
141
+ cumheights[..., -1] = top
142
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
143
+
144
+ if inverse:
145
+ bin_idx = searchsorted(cumheights, inputs)[..., None]
146
+ else:
147
+ bin_idx = searchsorted(cumwidths, inputs)[..., None]
148
+
149
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
150
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
151
+
152
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
153
+ delta = heights / widths
154
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
155
+
156
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
157
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
158
+
159
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
160
+
161
+ if inverse:
162
+ a = (inputs - input_cumheights) * (
163
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
164
+ ) + input_heights * (input_delta - input_derivatives)
165
+ b = input_heights * input_derivatives - (inputs - input_cumheights) * (
166
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
167
+ )
168
+ c = -input_delta * (inputs - input_cumheights)
169
+
170
+ discriminant = b.pow(2) - 4 * a * c
171
+ assert (discriminant >= 0).all()
172
+
173
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
174
+ outputs = root * input_bin_widths + input_cumwidths
175
+
176
+ theta_one_minus_theta = root * (1 - root)
177
+ denominator = input_delta + (
178
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
179
+ * theta_one_minus_theta
180
+ )
181
+ derivative_numerator = input_delta.pow(2) * (
182
+ input_derivatives_plus_one * root.pow(2)
183
+ + 2 * input_delta * theta_one_minus_theta
184
+ + input_derivatives * (1 - root).pow(2)
185
+ )
186
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
187
+
188
+ return outputs, -logabsdet
189
+ else:
190
+ theta = (inputs - input_cumwidths) / input_bin_widths
191
+ theta_one_minus_theta = theta * (1 - theta)
192
+
193
+ numerator = input_heights * (
194
+ input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
195
+ )
196
+ denominator = input_delta + (
197
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
198
+ * theta_one_minus_theta
199
+ )
200
+ outputs = input_cumheights + numerator / denominator
201
+
202
+ derivative_numerator = input_delta.pow(2) * (
203
+ input_derivatives_plus_one * theta.pow(2)
204
+ + 2 * input_delta * theta_one_minus_theta
205
+ + input_derivatives * (1 - theta).pow(2)
206
+ )
207
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
208
+
209
+ return outputs, logabsdet
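Because module/transforms.py is self-contained, its behaviour can be sanity-checked with a quick forward/inverse round trip. The sketch below is editorial (not part of the commit) and relies only on the signatures shown above; with tails="linear", the derivatives tensor carries num_bins - 1 values and inputs outside the tail bound pass through unchanged.

# Editorial sketch: round trip through the rational-quadratic spline.
import torch
from module.transforms import piecewise_rational_quadratic_transform  # run from repo root

torch.manual_seed(0)
num_bins = 10
x = torch.randn(4, 16)                # arbitrary (batch, time) inputs
w = torch.randn(4, 16, num_bins)      # unnormalized bin widths
h = torch.randn(4, 16, num_bins)      # unnormalized bin heights
d = torch.randn(4, 16, num_bins - 1)  # unnormalized inner-knot derivatives

y, logdet = piecewise_rational_quadratic_transform(
    x, w, h, d, inverse=False, tails="linear", tail_bound=5.0
)
x_rec, inv_logdet = piecewise_rational_quadratic_transform(
    y, w, h, d, inverse=True, tails="linear", tail_bound=5.0
)
# Invertible up to numerics; the two log-determinants cancel.
assert torch.allclose(x, x_rec, atol=1e-3)
assert torch.allclose(logdet, -inv_logdet, atol=1e-3)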
packages.txt ADDED
@@ -0,0 +1 @@
+ ffmpeg
pre-requirements.txt ADDED
@@ -0,0 +1,2 @@
+ torch==2.5.1
+ torchaudio
pretrained_models/chinese-hubert-base/config.json ADDED
@@ -0,0 +1,72 @@
+ {
+   "_name_or_path": "/data/docker/liujing04/gpt-vits/chinese-hubert-base",
+   "activation_dropout": 0.1,
+   "apply_spec_augment": true,
+   "architectures": [
+     "HubertModel"
+   ],
+   "attention_dropout": 0.1,
+   "bos_token_id": 1,
+   "classifier_proj_size": 256,
+   "conv_bias": false,
+   "conv_dim": [
+     512,
+     512,
+     512,
+     512,
+     512,
+     512,
+     512
+   ],
+   "conv_kernel": [
+     10,
+     3,
+     3,
+     3,
+     3,
+     2,
+     2
+   ],
+   "conv_stride": [
+     5,
+     2,
+     2,
+     2,
+     2,
+     2,
+     2
+   ],
+   "ctc_loss_reduction": "sum",
+   "ctc_zero_infinity": false,
+   "do_stable_layer_norm": false,
+   "eos_token_id": 2,
+   "feat_extract_activation": "gelu",
+   "feat_extract_norm": "group",
+   "feat_proj_dropout": 0.0,
+   "feat_proj_layer_norm": true,
+   "final_dropout": 0.1,
+   "hidden_act": "gelu",
+   "hidden_dropout": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-05,
+   "layerdrop": 0.1,
+   "mask_feature_length": 10,
+   "mask_feature_min_masks": 0,
+   "mask_feature_prob": 0.0,
+   "mask_time_length": 10,
+   "mask_time_min_masks": 2,
+   "mask_time_prob": 0.05,
+   "model_type": "hubert",
+   "num_attention_heads": 12,
+   "num_conv_pos_embedding_groups": 16,
+   "num_conv_pos_embeddings": 128,
+   "num_feat_extract_layers": 7,
+   "num_hidden_layers": 12,
+   "pad_token_id": 0,
+   "torch_dtype": "float16",
+   "transformers_version": "4.30.2",
+   "use_weighted_layer_sum": false,
+   "vocab_size": 32
+ }
pretrained_models/chinese-hubert-base/preprocessor_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "do_normalize": true,
+   "feature_extractor_type": "Wav2Vec2FeatureExtractor",
+   "feature_size": 1,
+   "padding_side": "right",
+   "padding_value": 0,
+   "return_attention_mask": false,
+   "sampling_rate": 16000
+ }
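The two JSON files above describe a standard transformers HuBERT checkpoint: 16 kHz input, 768-dim hidden states, fp16 weights. A minimal loading sketch follows; it is editorial, assumes the model weights have been downloaded into the same directory, and shows only the generic transformers path rather than the repository's own wrapper.

# Editorial sketch: load the content encoder described by the configs above.
import torch
from transformers import HubertModel, Wav2Vec2FeatureExtractor

path = "pretrained_models/chinese-hubert-base"  # weights assumed present here
extractor = Wav2Vec2FeatureExtractor.from_pretrained(path)
model = HubertModel.from_pretrained(path).eval()

wav16k = torch.randn(16000)  # 1 s of placeholder 16 kHz audio
inputs = extractor(wav16k.numpy(), sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    hidden = model(inputs.input_values).last_hidden_state  # (1, frames, 768)
print(hidden.shape)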
pretrained_models/chinese-roberta-wwm-ext-large/config.json ADDED
@@ -0,0 +1,34 @@
+ {
+   "_name_or_path": "/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large",
+   "architectures": [
+     "BertForMaskedLM"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "bos_token_id": 0,
+   "classifier_dropout": null,
+   "directionality": "bidi",
+   "eos_token_id": 2,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 16,
+   "num_hidden_layers": 24,
+   "output_past": true,
+   "pad_token_id": 0,
+   "pooler_fc_size": 768,
+   "pooler_num_attention_heads": 12,
+   "pooler_num_fc_layers": 3,
+   "pooler_size_per_head": 128,
+   "pooler_type": "first_token_transform",
+   "position_embedding_type": "absolute",
+   "torch_dtype": "float16",
+   "transformers_version": "4.30.2",
+   "type_vocab_size": 2,
+   "use_cache": true,
+   "vocab_size": 21128
+ }
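This config defines a 24-layer Chinese RoBERTa-wwm-ext-large exposed as a BertForMaskedLM. The editorial sketch below shows one generic way to pull hidden states from it as text features; the layer choice and the presence of local weights are assumptions, not the repository's inference code.

# Editorial sketch: tokenize Chinese text and extract hidden states.
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

path = "pretrained_models/chinese-roberta-wwm-ext-large"  # weights assumed present here
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForMaskedLM.from_pretrained(path).eval()

with torch.no_grad():
    inputs = tokenizer("今天天气不错。", return_tensors="pt")
    out = model(**inputs, output_hidden_states=True)
    feats = out.hidden_states[-3][0]  # (tokens, 1024); layer index is an assumption
print(feats.shape)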
pretrained_models/chinese-roberta-wwm-ext-large/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
process_ckpt.py ADDED
@@ -0,0 +1,31 @@
+ import traceback
+ from collections import OrderedDict
+ from time import time as ttime
+ import shutil,os
+ import torch
+ from tools.i18n.i18n import I18nAuto
+
+ i18n = I18nAuto()
+
+ def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
+     dir=os.path.dirname(path)
+     name=os.path.basename(path)
+     tmp_path="%s.pth"%(ttime())
+     torch.save(fea,tmp_path)
+     shutil.move(tmp_path,"%s/%s"%(dir,name))
+
+ def savee(ckpt, name, epoch, steps, hps):
+     try:
+         opt = OrderedDict()
+         opt["weight"] = {}
+         for key in ckpt.keys():
+             if "enc_q" in key:
+                 continue
+             opt["weight"][key] = ckpt[key].half()
+         opt["config"] = hps
+         opt["info"] = "%sepoch_%siteration" % (epoch, steps)
+         # torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
+         my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
+         return "Success."
+     except:
+         return traceback.format_exc()
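savee() writes a dict with "weight" (fp16 state dict, "enc_q.*" keys stripped), "config", and "info" to "<hps.save_weight_dir>/<name>.pth". Loading it back is symmetric; the sketch below is editorial and uses a hypothetical path.

# Editorial sketch: read back a checkpoint produced by savee() above.
import torch

ckpt = torch.load("SoVITS_weights/example.pth", map_location="cpu")  # hypothetical path
state_dict = ckpt["weight"]   # fp16 tensors, "enc_q.*" keys removed
hps = ckpt["config"]          # hyperparameters saved alongside the weights
print(ckpt["info"])           # e.g. "8epoch_2000iteration"
# model.load_state_dict(state_dict, strict=False)  # on a compatible generator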
requirements.txt ADDED
@@ -0,0 +1,36 @@
+ numpy<2.0
+ scipy>=1.11.3
+ tensorboard==2.15.1
+ librosa==0.9.2
+ numba==0.56.4
+ pytorch-lightning>=2.4
+ ffmpeg-python==0.2.0
+ onnxruntime-gpu
+ tqdm==4.66.4
+ cn2an==0.5.22
+ pypinyin==0.50.0
+ pyopenjtalk==0.4.1
+ g2p_en==2.1.0
+ sentencepiece==0.1.99
+ transformers==4.43.0
+ chardet==3.0.4
+ PyYAML==6.0.1
+ psutil==5.9.7
+ jieba_fast==0.53
+ jieba==0.42.1
+ https://hf-mirror.com/lj1995/GPT-SoVITS-windows-package/resolve/main/langsegment-0.3.5-py3-none-any.whl?download=true
+ wordsegment==1.3.1
+ rotary_embedding_torch==0.6.4
+ spaces
+ pyjyutping==1.0.0
+ g2pk2==0.0.3
+ ko_pron==1.3
+ opencc==1.1.0
+ python_mecab_ko==1.3.7
+ pydantic==2.8.2
+ torchmetrics<=1.5
+ nltk==3.8.1
+ fast_langdetect==0.3.1
+ split_lang==2.1.0
+ ToJyutping==3.2.0
+ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
sv.py ADDED
@@ -0,0 +1,24 @@
+ import sys,os,torch
+ sys.path.append(f"{os.getcwd()}/eres2net")
+ sv_path = "pretrained_models/sv/pretrained_eres2netv2w24s4ep4.ckpt"
+ from ERes2NetV2 import ERes2NetV2
+ import kaldi as Kaldi
+ class SV:
+     def __init__(self,device,is_half):
+         pretrained_state = torch.load(sv_path, map_location='cpu', weights_only=False)
+         embedding_model = ERes2NetV2(baseWidth=24,scale=4,expansion=4)
+         embedding_model.load_state_dict(pretrained_state)
+         embedding_model.eval()
+         self.embedding_model=embedding_model
+         if is_half == False:
+             self.embedding_model=self.embedding_model.to(device)
+         else:
+             self.embedding_model=self.embedding_model.half().to(device)
+         self.is_half=is_half
+
+     def compute_embedding3(self,wav):
+         with torch.no_grad():
+             if self.is_half==True:wav=wav.half()
+             feat = torch.stack([Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav])
+             sv_emb = self.embedding_model.forward3(feat)
+         return sv_emb
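The wrapper above turns a batch of 16 kHz waveforms into 80-bin fbank features and runs the ERes2NetV2 speaker model. An editorial usage sketch, assuming the checkpoint referenced by sv_path has already been downloaded:

# Editorial sketch: compute a speaker-verification embedding with the SV wrapper above.
import torch
from sv import SV  # run from repo root; checkpoint at sv_path assumed present

device = "cuda" if torch.cuda.is_available() else "cpu"
sv_model = SV(device, is_half=False)

wav = torch.randn(1, 16000 * 3).to(device)  # (batch, samples) of 16 kHz audio
emb = sv_model.compute_embedding3(wav)      # batch of speaker embeddings
print(emb.shape)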
text/.gitignore ADDED
@@ -0,0 +1,3 @@
+ G2PWModel
+ __pycache__
+ *.zip
text/LangSegmenter/__init__.py ADDED
@@ -0,0 +1 @@
+ from .langsegmenter import LangSegmenter
text/LangSegmenter/langsegmenter.py ADDED
@@ -0,0 +1,175 @@
+ import logging
+ import re
+
+ # silence jieba's logging
+ import jieba
+ jieba.setLogLevel(logging.CRITICAL)
+
+ # change where fast_langdetect stores its large model
+ from pathlib import Path
+ import fast_langdetect
+ fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(fast_langdetect.infer.LangDetectConfig(cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"))
+
+
+ from split_lang import LangSplitter
+
+
+ def full_en(text):
+     pattern = r'^(?=.*[A-Za-z])[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$'
+     return bool(re.match(pattern, text))
+
+
+ def full_cjk(text):
+     # ranges from the wiki
+     cjk_ranges = [
+         (0x4E00, 0x9FFF),  # CJK Unified Ideographs
+         (0x3400, 0x4DB5),  # CJK Extension A
+         (0x20000, 0x2A6DD),  # CJK Extension B
+         (0x2A700, 0x2B73F),  # CJK Extension C
+         (0x2B740, 0x2B81F),  # CJK Extension D
+         (0x2B820, 0x2CEAF),  # CJK Extension E
+         (0x2CEB0, 0x2EBEF),  # CJK Extension F
+         (0x30000, 0x3134A),  # CJK Extension G
+         (0x31350, 0x323AF),  # CJK Extension H
+         (0x2EBF0, 0x2EE5D),  # CJK Extension H
+     ]
+
+     pattern = r'[0-9、-〜。!?.!?… /]+$'
+
+     cjk_text = ""
+     for char in text:
+         code_point = ord(char)
+         in_cjk = any(start <= code_point <= end for start, end in cjk_ranges)
+         if in_cjk or re.match(pattern, char):
+             cjk_text += char
+     return cjk_text
+
+
+ def split_jako(tag_lang,item):
+     if tag_lang == "ja":
+         pattern = r"([\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]+(?:[0-9、-〜。!?.!?… ]+[\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]*)*)"
+     else:
+         pattern = r"([\u1100-\u11FF\u3130-\u318F\uAC00-\uD7AF]+(?:[0-9、-〜。!?.!?… ]+[\u1100-\u11FF\u3130-\u318F\uAC00-\uD7AF]*)*)"
+
+     lang_list: list[dict] = []
+     tag = 0
+     for match in re.finditer(pattern, item['text']):
+         if match.start() > tag:
+             lang_list.append({'lang':item['lang'],'text':item['text'][tag:match.start()]})
+
+         tag = match.end()
+         lang_list.append({'lang':tag_lang,'text':item['text'][match.start():match.end()]})
+
+     if tag < len(item['text']):
+         lang_list.append({'lang':item['lang'],'text':item['text'][tag:len(item['text'])]})
+
+     return lang_list
+
+
+ def merge_lang(lang_list, item):
+     if lang_list and item['lang'] == lang_list[-1]['lang']:
+         lang_list[-1]['text'] += item['text']
+     else:
+         lang_list.append(item)
+     return lang_list
+
+
+ class LangSegmenter():
+     # default filter, based on the four languages gsv currently supports
+     DEFAULT_LANG_MAP = {
+         "zh": "zh",
+         "yue": "zh",  # Cantonese
+         "wuu": "zh",  # Wu Chinese
+         "zh-cn": "zh",
+         "zh-tw": "x",  # Traditional Chinese is mapped to x
+         "ko": "ko",
+         "ja": "ja",
+         "en": "en",
+     }
+
+
+     def getTexts(text):
+         lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP)
+         substr = lang_splitter.split_by_lang(text=text)
+
+         lang_list: list[dict] = []
+
+         for _, item in enumerate(substr):
+             dict_item = {'lang':item.lang,'text':item.text}
+
+             # handle short English spans misdetected as another language
+             if full_en(dict_item['text']):
+                 dict_item['lang'] = 'en'
+                 lang_list = merge_lang(lang_list,dict_item)
+                 continue
+
+             # handle Japanese embedded in non-Japanese text (excluding CJK)
+             ja_list: list[dict] = []
+             if dict_item['lang'] != 'ja':
+                 ja_list = split_jako('ja',dict_item)
+
+             if not ja_list:
+                 ja_list.append(dict_item)
+
+             # handle Korean embedded in non-Korean text (excluding CJK)
+             ko_list: list[dict] = []
+             temp_list: list[dict] = []
+             for _, ko_item in enumerate(ja_list):
+                 if ko_item["lang"] != 'ko':
+                     ko_list = split_jako('ko',ko_item)
+
+                 if ko_list:
+                     temp_list.extend(ko_list)
+                 else:
+                     temp_list.append(ko_item)
+
+             # no Japanese/Korean embedded in other-language text
+             if len(temp_list) == 1:
+                 # unknown language: check whether it is CJK
+                 if dict_item['lang'] == 'x':
+                     cjk_text = full_cjk(dict_item['text'])
+                     if cjk_text:
+                         dict_item = {'lang':'zh','text':cjk_text}
+                         lang_list = merge_lang(lang_list,dict_item)
+                     else:
+                         lang_list = merge_lang(lang_list,dict_item)
+                     continue
+                 else:
+                     lang_list = merge_lang(lang_list,dict_item)
+                     continue
+
+             # Japanese/Korean embedded in other-language text
+             for _, temp_item in enumerate(temp_list):
+                 # unknown language: check whether it is CJK
+                 if temp_item['lang'] == 'x':
+                     cjk_text = full_cjk(dict_item['text'])
+                     if cjk_text:
+                         dict_item = {'lang':'zh','text':cjk_text}
+                         lang_list = merge_lang(lang_list,dict_item)
+                     else:
+                         lang_list = merge_lang(lang_list,dict_item)
+                 else:
+                     lang_list = merge_lang(lang_list,temp_item)
+
+         temp_list = lang_list
+         lang_list = []
+         for _, temp_item in enumerate(temp_list):
+             if temp_item['lang'] == 'x':
+                 if lang_list:
+                     temp_item['lang'] = lang_list[-1]['lang']
+                 elif len(temp_list) > 1:
+                     temp_item['lang'] = temp_list[1]['lang']
+                 else:
+                     temp_item['lang'] = 'zh'
+
+             lang_list = merge_lang(lang_list,temp_item)
+
+         return lang_list
+
+
+ if __name__ == "__main__":
+     text = "MyGO?,你也喜欢まいご吗?"
+     print(LangSegmenter.getTexts(text))
+
+     text = "ねえ、知ってる?最近、僕は天文学を勉強してるんだ。君の瞳が星空みたいにキラキラしてるからさ。"
+     print(LangSegmenter.getTexts(text))