admin117 committed
Commit bd3dc09
Parent(s): 86970b7

Initial commit without LFS files

This view is limited to 50 files because it contains too many changes. See raw diff for the full set.
- .gitignore +3 -0
- AR/models/embedding.py +98 -0
- AR/models/structs.py +91 -0
- AR/models/t2s_model_abc.py +538 -0
- AR/models/t2s_model_flash_attn.py +408 -0
- configs/s1.yaml +31 -0
- configs/s1big.yaml +31 -0
- configs/s1big2.yaml +31 -0
- configs/s1longer-v2.yaml +31 -0
- configs/s1longer.yaml +31 -0
- configs/s1mq.yaml +77 -0
- configs/s2.json +90 -0
- configs/train.yaml +32 -0
- download.py +5 -0
- eres2net/ERes2Net.py +260 -0
- eres2net/ERes2NetV2.py +292 -0
- eres2net/ERes2Net_huge.py +286 -0
- eres2net/fusion.py +29 -0
- eres2net/kaldi.py +819 -0
- eres2net/pooling_layers.py +104 -0
- feature_extractor/__init__.py +6 -0
- feature_extractor/cnhubert.py +109 -0
- feature_extractor/whisper_enc.py +25 -0
- inference_webui.py +867 -0
- module/__init__.py +0 -0
- module/attentions.py +709 -0
- module/attentions_onnx.py +354 -0
- module/commons.py +189 -0
- module/core_vq.py +383 -0
- module/data_utils.py +332 -0
- module/losses.py +73 -0
- module/mel_processing.py +153 -0
- module/models.py +1040 -0
- module/models_onnx.py +918 -0
- module/modules.py +923 -0
- module/mrte_model.py +192 -0
- module/quantize.py +119 -0
- module/transforms.py +209 -0
- packages.txt +1 -0
- pre-requirements.txt +2 -0
- pretrained_models/chinese-hubert-base/config.json +72 -0
- pretrained_models/chinese-hubert-base/preprocessor_config.json +9 -0
- pretrained_models/chinese-roberta-wwm-ext-large/config.json +34 -0
- pretrained_models/chinese-roberta-wwm-ext-large/tokenizer.json +0 -0
- process_ckpt.py +31 -0
- requirements.txt +36 -0
- sv.py +24 -0
- text/.gitignore +3 -0
- text/LangSegmenter/__init__.py +1 -0
- text/LangSegmenter/langsegmenter.py +175 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
+*.pickle
+text/ja_userdic/user.dict
+text/ja_userdic/userdict.csv
AR/models/embedding.py
ADDED
@@ -0,0 +1,98 @@
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
import math

import torch
from torch import nn


class TokenEmbedding(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        vocab_size: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim

        self.dropout = torch.nn.Dropout(p=dropout)
        self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)

    @property
    def weight(self) -> torch.Tensor:
        return self.word_embeddings.weight

    def embedding(self, index: int) -> torch.Tensor:
        return self.word_embeddings.weight[index : index + 1]

    def forward(self, x: torch.Tensor):
        x = self.word_embeddings(x)
        x = self.dropout(x)
        return x


class SinePositionalEmbeddingNested(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        dropout: float = 0.0,
        scale: bool = False,
        alpha: bool = False,
        max_batch_size: int = 20,
        max_seq_len: int = 2500,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len

        self.reverse = False
        self.register_buffer("pe", torch.zeros(max_batch_size, max_seq_len, embedding_dim), persistent=False)
        self.pe: torch.Tensor
        self.compute_pe()

    def compute_pe(self):
        """Reset the positional encodings."""
        if self.reverse:
            position = torch.arange(self.max_seq_len - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
        else:
            position = torch.arange(self.max_seq_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
        )
        pe = self.pe
        pe[:, :, 0::2] = torch.sin(position * div_term)
        pe[:, :, 1::2] = torch.cos(position * div_term)

    def forward(self, input_pos: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input_pos (Tensor): [batch_size, ]
            x (Tensor): [batch_size, 1, embed_dim]

        Returns:
            embedded_x (Tensor): [batch_size, 1, embed_dim]
        """

        batch_size = x.shape[0]
        pe_values = self.pe[torch.arange(batch_size), input_pos - 1]  # (batch_size, embed_dim)

        return x * self.x_scale + self.alpha * pe_values.unsqueeze(1)  # (batch_size, 1, embed_dim)

    def prefill(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (Tensor): Nested Seqlen [batch_size, seq_len, embed_dim]

        Returns:
            embedded_x (Tensor): Nested Seqlen [batch_size, seq_len, embed_dim]
        """

        input_pos: torch.Tensor = torch.tensor([i.shape[0] for i in x.unbind()])
        pe_values = torch.nested.nested_tensor([self.pe[i, : input_pos[i], :] for i in range(input_pos.size(0))])
        return x * self.x_scale + self.alpha.item() * pe_values
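For orientation, a minimal usage sketch of the two modules above. The sizes (embedding_dim=512, vocab_size=1025) are illustrative assumptions taken from the configs later in this commit, not values baked into the module.

import torch

from AR.models.embedding import SinePositionalEmbeddingNested, TokenEmbedding

# Illustrative sizes; real values come from the model config.
embed = TokenEmbedding(embedding_dim=512, vocab_size=1025)
pos = SinePositionalEmbeddingNested(embedding_dim=512, alpha=True, max_batch_size=2)

tokens = torch.randint(0, 1025, (2, 1))  # one new token per sequence in the batch
input_pos = torch.tensor([10, 7])        # current (1-based) position of each sequence
x = embed(tokens)                        # (2, 1, 512)
out = pos(input_pos, x)                  # adds pe[b, input_pos[b] - 1] -> (2, 1, 512)

The per-batch `pe` buffer is what lets `forward` look up a different position for every sequence in one decoding step, which is the point of the "Nested" variant.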
AR/models/structs.py
ADDED
@@ -0,0 +1,91 @@
"""
Modified From https://github.com/XXXXRT666/GPT-SoVITS
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import List, Literal, MutableSequence, Optional

import torch

from AR.models.t2s_model_abc import KVCacheABC, Sampler, T2SDecoderABC

Tensor = torch.Tensor


@dataclass
class T2SResult:
    result: List[Tensor] | None = None
    infer_speed: float = 0.0
    status: Literal["Success", "Error"] = "Success"
    exception: Optional[Exception] = None
    traceback: Optional[str] = None


@dataclass
class T2SRequest:
    x: List[torch.Tensor]
    x_lens: Tensor
    prompts: torch.Tensor
    bert_feature: List[Tensor]
    valid_length: int
    top_k: int = 5
    top_p: float = 1
    early_stop_num: int = -1
    temperature: float = 1.0
    repetition_penalty: float = 1.35
    use_cuda_graph: bool = False
    debug: bool = False


class T2SSession:
    def __init__(self, decoder: T2SDecoderABC, request: T2SRequest, device: torch.device, dtype: torch.dtype):
        with device:
            self.decoder = decoder
            self.request = request
            self.device = device
            self.dtype = dtype

            bsz = len(request.x)
            y_len = request.prompts.size(-1)
            self.bsz = bsz
            self.y_len = y_len

            # Cache
            self.kv_cache: MutableSequence[KVCacheABC]
            self.sampler = Sampler(bsz, decoder.vocab_size)

            # Forward args
            self.x = request.x
            self.x_lens = request.x_lens.to(torch.int32)
            self.y = request.prompts
            self.bert_feature = request.bert_feature

            self.prefill_len = self.x_lens + self.y.size(1)

            self.input_pos = torch.zeros_like(self.prefill_len)
            self.input_pos.add_(self.prefill_len)

            # CUDA Graph
            self.graph: Optional[torch.cuda.CUDAGraph] = None
            self.xy_pos_: Tensor
            self.xy_dec_: Tensor

            # EOS
            self.completed = torch.Tensor([False] * len(self.x)).bool().to(device)
            self.y_results: List[Tensor] = [None] * len(self.x)  # type: ignore

            self.xy_pos = decoder.embed(self.x, self.y, self.bert_feature)

            attn_mask = []
            for bs in range(bsz):
                pos = int(self.x_lens[bs].item())
                mask = torch.zeros(pos + y_len, pos + y_len).bool()
                mask[:, :pos].fill_(True)
                if y_len > 0:
                    mask[-y_len:, -y_len:] = ~torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1)
                attn_mask.append(mask)
            self.attn_mask_nested = torch.nested.nested_tensor(attn_mask)

            self.id: int = -1
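A hedged sketch of constructing a `T2SRequest`. The shapes are assumptions inferred from `T2SSession` and the `embed` implementation in the flash-attn decoder below (phoneme ids per sequence, BERT features as (1024, x_len), prompt semantic tokens as (bsz, prompt_len)); the concrete numbers are dummies.

import torch

from AR.models.structs import T2SRequest

# Dummy inputs with illustrative shapes; in the real pipeline these come from
# the text frontend, the BERT feature extractor, and the reference prompt.
x = [torch.randint(0, 732, (24,))]         # phoneme ids for one sequence
x_lens = torch.tensor([24])
prompts = torch.randint(0, 1024, (1, 16))  # prompt semantic tokens
bert_feature = [torch.randn(1024, 24)]     # (1024, x_len); transposed inside embed()
request = T2SRequest(x, x_lens, prompts, bert_feature, valid_length=1, top_k=15)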
AR/models/t2s_model_abc.py
ADDED
@@ -0,0 +1,538 @@
"""
Modified From https://github.com/XXXXRT666/GPT-SoVITS
"""

from __future__ import annotations

import os
import random
from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import Any, Dict, List, MutableSequence, Tuple, Type

import torch
import torch._inductor.config
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.graphs import CUDAGraph
from torch.profiler import ProfilerAction, tensorboard_trace_handler

from AR.models.embedding import (
    SinePositionalEmbeddingNested as SinePositionalEmbedding,
)
from AR.models.embedding import TokenEmbedding

Tensor = torch.Tensor


class Sampler(nn.Module):
    def __init__(self, batch_size: int, vocab_size: int) -> None:
        super().__init__()
        self.batch_size = batch_size

    # @torch.jit.script
    def sample(
        self,
        logits: Tensor,
        previous_tokens: Tensor,
        temperature: float,
        top_k: int,
        top_p: float,
        repetition_penalty: float,
    ) -> Tensor:
        previous_tokens = previous_tokens.long()
        score = torch.gather(logits, dim=1, index=previous_tokens)
        score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
        logits.scatter_(dim=1, index=previous_tokens, src=score)

        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cum_probs > top_p
        sorted_indices_to_remove[:, 0] = False  # keep at least one option
        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, -float("Inf"))

        logits = logits / max(temperature, 1e-5)

        v, _ = torch.topk(logits, top_k)
        pivot = v[:, -1].unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)

        probs = torch.nn.functional.softmax(logits, dim=-1)
        q = torch.empty_like(probs).exponential_(1.0)
        idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32)

        return idx_next


class KVCacheABC(ABC, nn.Module):
    def __init__(self, *args, **kwds) -> None:
        super().__init__()
        self.k_cache: Tensor
        self.v_cache: Tensor
        self.n_head: int
        self.head_dim: int
        self.batch_size: int
        self.max_seq_length: int

    def empty(self):
        self.k_cache.zero_()
        self.v_cache.zero_()

    @abstractmethod
    def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor, *args, **kwds) -> Tuple[Tensor, Tensor]: ...

    @abstractmethod
    def prefill_kv(self, k_val: Tensor, v_val: Tensor, bs: int) -> None: ...

    def sync_cache(self, kv_cache: KVCacheABC):
        self.k_cache.copy_(kv_cache.k_cache)
        self.v_cache.copy_(kv_cache.v_cache)

    def forward(self):
        raise NotImplementedError()


class KVCacheNHD(KVCacheABC):
    def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
        super().__init__()
        assert batch_size > 0
        cache_shape = (batch_size, max_seq_length, n_heads, head_dim)
        self.n_head = n_heads
        self.head_dim = head_dim
        self.batch_size = batch_size
        self.max_seq_length = max_seq_length

        self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
        self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)

    def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
        # input_pos: [B, ], k_val: [B, 1, H, D]

        index = (
            (input_pos - 1)
            .unsqueeze(-1)
            .unsqueeze(-1)
            .unsqueeze(-1)
            .expand(
                -1,
                -1,
                self.n_head,
                self.head_dim,
            )
            .to(torch.int64)
        )  # (bs, 1, num_head, head_dim)

        k_out = self.k_cache
        v_out = self.v_cache
        k_out.scatter_(1, index, k_val)
        v_out.scatter_(1, index, v_val)

        return k_out, v_out

    def empty(self):
        self.k_cache.zero_()
        self.v_cache.zero_()

    def prefill_kv(self, k_val: Tensor, v_val: Tensor, bs: int):
        # input_pos: int, k_val: [B, S, H, D]

        self.k_cache[[bs], : k_val.shape[1]] = k_val
        self.v_cache[[bs], : v_val.shape[1]] = v_val


class KVCacheHND(KVCacheABC):
    def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
        super().__init__()
        assert batch_size > 0
        cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
        self.n_head = n_heads
        self.head_dim = head_dim
        self.batch_size = batch_size
        self.max_seq_length = max_seq_length

        self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
        self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)

    def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
        # input_pos: [B, ], k_val: [B, H, 1, D]

        index = (
            (input_pos - 1)
            .unsqueeze(-1)
            .unsqueeze(-1)
            .unsqueeze(-1)
            .expand(
                -1,
                self.n_head,
                -1,
                self.head_dim,
            )
            .to(torch.int64)
        )  # (bs, num_head, 1, head_dim)

        k_out = self.k_cache
        v_out = self.v_cache
        k_out.scatter_(2, index, k_val)
        v_out.scatter_(2, index, v_val)

        return k_out, v_out

    def empty(self):
        self.k_cache.zero_()
        self.v_cache.zero_()

    def prefill_kv(self, k_val: Tensor, v_val: Tensor, bs: int):
        # input_pos: int, k_val: [B, S, H, D]

        self.k_cache[[bs], :, : k_val.shape[1]] = k_val.transpose(1, 2)
        self.v_cache[[bs], :, : v_val.shape[1]] = v_val.transpose(1, 2)


class AttentionABC(ABC, nn.Module):
    def __init__(self):
        super().__init__()
        self.n_head: int
        self.hidden_dim: int
        self.head_dim: int

        # key, query, value projections for all heads, but in a batch
        self.in_proj: nn.Linear
        self.out_proj: nn.Linear

        self.dropout = nn.Dropout(0.1)

        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict: dict, prefix, *args):
        keys_to_modify = [key for key in state_dict if "in_proj_" in key]
        for key in keys_to_modify:
            new_key = key.replace("in_proj_", "in_proj.")  # in_proj_ -> in_proj.
            state_dict[new_key] = state_dict.pop(key)

    @abstractmethod
    def forward(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheABC, *args, **kwds) -> Tensor: ...

    def prefill(self, x: Tensor, mask: Tensor, kv_cache: KVCacheABC) -> Tensor:
        bsz = x.size(0)

        outputs = []

        for bs in range(bsz):
            x_b = x[bs].unsqueeze(0)

            q, k, v = self.in_proj.forward(x_b.unsqueeze(0)).chunk(3, dim=-1)

            q = q.contiguous().view(1, -1, self.n_head, self.head_dim)
            k = k.contiguous().view(1, -1, self.n_head, self.head_dim)
            v = v.contiguous().view(1, -1, self.n_head, self.head_dim)

            kv_cache.prefill_kv(k, v, bs)

            q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

            attn_mask = mask[bs].unsqueeze(0).unsqueeze(0).expand(1, self.n_head, -1, -1)

            attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

            attn = self.dropout.forward(attn)

            attn = attn.transpose(1, 2).contiguous().view(1, -1, self.hidden_dim)

            output = self.out_proj.forward(attn)

            outputs.append(output.squeeze(0))

        return torch.nested.nested_tensor(outputs)


class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.linear1 = nn.Linear(dim, hidden_dim, bias=True)
        self.linear2 = nn.Linear(hidden_dim, dim, bias=True)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x: Tensor) -> Tensor:
        return self.dropout.forward(self.linear2(self.dropout.forward(F.relu(self.linear1(x)))))


class TransformerBlockABC(ABC, nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.hidden_dim: int
        self.attention: AttentionABC
        self.feed_forward: FeedForward
        self.attention_norm: nn.LayerNorm
        self.ffn_norm: nn.LayerNorm
        self.dropout = nn.Dropout(0.1)

        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
        for key in list(state_dict.keys()):
            new_key = (
                key.replace("self_attn", "attention")
                .replace("linear", "feed_forward.linear")
                .replace("norm1", "attention_norm")
                .replace("norm2", "ffn_norm")
            )
            state_dict[new_key] = state_dict.pop(key)

    def forward(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheABC, *args, **kwds) -> Tensor:
        h = self.attention_norm.forward(
            x
            + self.dropout.forward(
                self.attention.forward(
                    x,
                    input_pos,
                    kv_cache,
                    *args,
                    **kwds,
                )
            )
        )
        out = self.ffn_norm.forward(h + self.feed_forward.forward(h))
        return out

    def prefill(self, x: Tensor, mask: Tensor, kv_cache: KVCacheABC) -> Tensor:
        h = self.attention_norm.forward(
            x
            + self.dropout.forward(
                self.attention.prefill(
                    x,
                    mask,
                    kv_cache,
                )
            )
        )
        out = self.ffn_norm.forward(h + self.feed_forward.forward(h))
        return out


class TransformerDecoderABC(ABC, nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.hidden_dim: int
        self.n_head: int
        self.head_dim: int
        self.vocab_size: int
        self.n_layer: int

        self.layers: MutableSequence[TransformerBlockABC]

        self.max_seq_length: int
        self.max_batch_size: int

        self.input_pos: Tensor
        self.xy_pos: Tensor
        self.xy_dec: Tensor

    def forward(self, input_pos: Tensor, x: Tensor, kv_caches: MutableSequence[KVCacheABC], *args, **kwds):
        for layer, kv_cache in zip(self.layers, kv_caches):
            x = layer.forward(x, input_pos, kv_cache, *args, **kwds)
        return x

    def prefill(self, x: Tensor, mask: Tensor, kv_caches: MutableSequence[KVCacheABC]):
        for layer, kv_cache in zip(self.layers, kv_caches):
            x = layer.prefill(x, mask, kv_cache)
        return x


class T2SDecoderABC(ABC, nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.n_layer: int
        self.hidden_dim: int
        self.n_head: int

        self.head_dim: int
        self.embedding_dim: int
        self.vocab_size: int
        self.phoneme_vocab_size: int
        self.p_dropout: float
        self.max_seq_length: int
        self.max_batch_size: int
        self.EOS: int

        self.bert_proj: nn.Linear
        self.ar_text_embedding: TokenEmbedding
        self.ar_text_position: SinePositionalEmbedding
        self.ar_audio_embedding: TokenEmbedding
        self.ar_audio_position: SinePositionalEmbedding
        self.ar_predict_layer: nn.Linear
        self.h: TransformerDecoderABC

        self.kv_class: Type[KVCacheNHD] | Type[KVCacheHND]

        self.GraphCache: CUDAGraphCacheABC | None

        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        model_keys = [key for key in state_dict if key.startswith("model.")]
        for key in model_keys:
            new_key = key[len("model.") :]
            state_dict[new_key] = state_dict.pop(key)

    def init_cache(self, bsz: int = 0) -> MutableSequence[KVCacheABC]:
        bsz = bsz or self.h.max_batch_size
        assert bsz <= self.h.max_batch_size
        seq_lens = self.h.max_seq_length
        device = self.bert_proj.bias.device
        dtype = self.bert_proj.bias.dtype
        kvclass = self.kv_class
        return nn.ModuleList(
            [kvclass(bsz, seq_lens, self.n_head, self.head_dim) for _ in range(self.n_layer)],
        ).to(device, dtype)  # type: ignore

    @abstractmethod
    def embed(self, x: List[torch.Tensor], y: torch.Tensor, bert_features: List[Tensor]) -> Tensor: ...

    def compile(self, *args, **kwds):
        torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True
        torch._inductor.config.coordinate_descent_tuning = True
        torch._inductor.config.triton.unique_kernel_names = True
        # Experimental features to reduce compilation times, will be on by default in future
        torch._inductor.config.fx_graph_cache = True
        torch._inductor.config.triton.cudagraph_trees = True
        torch._inductor.config.triton.cudagraph_support_input_mutation = True
        self.h.compile(fullgraph=True, mode="reduce-overhead")

    def capture(self, input_pos: Tensor, x: Tensor, x_dec: Tensor, *args, **kwds) -> CUDAGraph:
        assert torch.cuda.is_available()
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())

        graph = torch.cuda.CUDAGraph()

        with torch.cuda.stream(s):  # type: ignore
            for _ in range(5):
                self.h.forward(input_pos, x, *args, **kwds)
        torch.cuda.current_stream().wait_stream(s)

        with torch.cuda.graph(graph):
            x_dec.copy_(self.h.forward(input_pos, x, *args, **kwds))
        torch.cuda.synchronize()

        return graph

    @abstractmethod
    def pre_forward(self, session: Any) -> Tuple[List, Dict]: ...

    @abstractmethod
    def post_forward(self, idx: int, session: Any) -> None: ...


class CUDAGraphCacheABC(ABC):
    def __init__(
        self,
        decoder: T2SDecoderABC,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ) -> None:
        assert torch.cuda.is_available()

        self.assigned: bool = False

        self.decoder: T2SDecoderABC = decoder
        self.kv_cache: MutableSequence[KVCacheABC] = decoder.init_cache(1)
        self.xy_pos = torch.rand((1, 1, decoder.embedding_dim), device=device).to(dtype)
        self.xy_dec = torch.rand((1, 1, decoder.embedding_dim), device=device).to(dtype)
        self.input_pos = torch.tensor([10]).int().cuda()
        self.graph: torch.cuda.CUDAGraph | None = None

        self.id: int = random.randint(1, 2**32 - 1)

    def assign_graph(self, session: Any):
        if self.graph is None:
            args, kwds = self.decoder.pre_forward(session)
            graph = self.decoder.capture(
                self.input_pos, self.xy_pos, self.xy_dec, kv_caches=self.kv_cache, *args, **kwds
            )
            self.graph = graph

        if self.assigned is False:
            self.get_cache_graph(session)
            session.id = self.id
            self.assigned = True
        else:
            self.capture_new_graph(session)

    @abstractmethod
    def release_graph(self, session: Any): ...

    @abstractmethod
    def get_cache_graph(self, session: Any):
        pass

    @abstractmethod
    def capture_new_graph(self, session: Any):
        pass


class TorchProfiler:
    def __init__(self, debug: bool, log_dir: str = "./profiler") -> None:
        self.debug = debug
        self.log_dir = log_dir
        self.__profiler: torch.profiler.profile

        if self.debug and not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        self.tensorboard_handler = tensorboard_trace_handler(self.log_dir)

    def profiler_callback(self, prof: torch.profiler.profile):
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=30))
        self.tensorboard_handler(prof)

    @staticmethod
    def three_step_schedule(step: int) -> ProfilerAction:
        if step == 0:
            return ProfilerAction.NONE
        elif step == 1:
            return ProfilerAction.RECORD
        elif step == 2:
            return ProfilerAction.RECORD_AND_SAVE
        else:
            return ProfilerAction.NONE

    def start(self):
        if not self.debug:
            return
        assert self.__profiler is not None
        self.__profiler.step()

    def end(self):
        if not self.debug:
            return
        assert self.__profiler is not None
        self.__profiler.step()

    def profiler(self):
        if self.debug:
            activities_list = [torch.profiler.ProfilerActivity.CPU]
            if torch.cuda.is_available():
                activities_list.append(torch.profiler.ProfilerActivity.CUDA)

            self.__profiler = torch.profiler.profile(
                activities=activities_list,
                record_shapes=True,
                with_stack=True,
                with_modules=True,
                profile_memory=True,
                schedule=self.three_step_schedule,
                on_trace_ready=self.profiler_callback,
            )
            return self.__profiler
        else:
            return nullcontext()

    def record(self, func_name: str):
        if self.debug:
            return torch.profiler.record_function(func_name)
        else:
            return nullcontext()
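To make the `Sampler` contract concrete, a self-contained sketch with illustrative shapes (one decoding step of logits over a 1025-token vocabulary; `previous_tokens` is the generated history used for the repetition penalty):

import torch

from AR.models.t2s_model_abc import Sampler

sampler = Sampler(batch_size=2, vocab_size=1025)
logits = torch.randn(2, 1025)                      # one step of decoder output
previous_tokens = torch.randint(0, 1024, (2, 16))  # generated history per sequence
next_token = sampler.sample(
    logits,
    previous_tokens,
    temperature=1.0,
    top_k=15,
    top_p=0.9,
    repetition_penalty=1.35,
)  # (2, 1) int32 token ids

The final `argmax(probs / q)` with `q ~ Exponential(1)` is the exponential-race trick: each index wins with probability proportional to `probs[i]`, so it samples from the categorical distribution without `torch.multinomial`, which keeps the hot step friendly to CUDA-graph capture.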
AR/models/t2s_model_flash_attn.py
ADDED
@@ -0,0 +1,408 @@
"""
Modified From https://github.com/XXXXRT666/GPT-SoVITS
"""

import gc
import os
import time
import traceback
from typing import Dict, List, Tuple

import flash_attn  # type: ignore
import torch
import torch.nn as nn
from tqdm import tqdm

from AR.models.embedding import (
    SinePositionalEmbeddingNested as SinePositionalEmbedding,
)
from AR.models.embedding import TokenEmbedding
from AR.models.structs import T2SRequest, T2SResult, T2SSession
from AR.models.t2s_model_abc import (
    AttentionABC,
    CUDAGraphCacheABC,
    FeedForward,
    KVCacheABC,
    KVCacheNHD,
    T2SDecoderABC,
    TorchProfiler,
    TransformerBlockABC,
    TransformerDecoderABC,
)

Tensor = torch.Tensor


class Attention(AttentionABC):
    def __init__(self, n_head: int, hidden_dim: int):
        super().__init__()
        self.n_head = n_head
        self.hidden_dim = hidden_dim
        assert hidden_dim % n_head == 0
        self.head_dim = hidden_dim // n_head

        self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheABC, *args, **kwds) -> Tensor:
        bsz, seqlen, _ = x.shape

        q, k, v = self.in_proj.forward(x).chunk(3, dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_head, self.head_dim)
        v = v.view(bsz, seqlen, self.n_head, self.head_dim)

        attn: Tensor = flash_attn.flash_attn_with_kvcache(
            q, kv_cache.k_cache, kv_cache.v_cache, k, v, cache_seqlens=input_pos - 1
        )  # type: ignore

        attn = self.dropout.forward(attn)

        attn = attn.view(bsz, seqlen, self.hidden_dim)

        attn = self.out_proj.forward(attn)

        return attn


class TransformerBlock(TransformerBlockABC):
    def __init__(self, n_head, ffn_dim, hidden_dim) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.attention = Attention(n_head, hidden_dim)
        self.feed_forward = FeedForward(hidden_dim, ffn_dim)
        self.attention_norm = nn.LayerNorm([self.hidden_dim])
        self.ffn_norm = nn.LayerNorm([self.hidden_dim])


class TransformerDecoder(TransformerDecoderABC):
    def __init__(
        self,
        hidden_dim,
        n_layer,
        n_head,
        ffn_dim,
        vocab_size,
        max_seq_length,
        max_batch_size,
    ) -> None:
        super().__init__()

        self.hidden_dim = hidden_dim
        self.n_head = n_head
        assert hidden_dim % n_head == 0

        self.head_dim = hidden_dim // n_head
        self.vocab_size = vocab_size

        self.n_layer = n_layer

        self.layers = nn.ModuleList(  # type: ignore
            TransformerBlock(n_head, ffn_dim, hidden_dim) for _ in range(n_layer)
        )

        self.max_seq_length: int = max_seq_length
        self.max_batch_size: int = max_batch_size

        self.setup_caches(self.max_batch_size, self.max_seq_length)

    def setup_caches(self, max_batch_size=10, max_seq_length=2500):
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size


class T2SDecoder(T2SDecoderABC):
    def __init__(
        self,
        config,
        *args,
        norm_first=False,
        max_seq_length=2500,
        max_batch_size=10,
        **kwds,
    ) -> None:
        assert torch.cuda.is_available()
        super().__init__()

        hidden_dim = config["model"]["hidden_dim"]
        embedding_dim = config["model"]["embedding_dim"]
        n_head = config["model"]["head"]
        n_layer = config["model"]["n_layer"]
        vocab_size = config["model"]["vocab_size"]
        phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
        p_dropout = config["model"]["dropout"]
        EOS = config["model"]["EOS"]
        ffn_dim = hidden_dim * 4
        self.norm_first = norm_first

        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.n_head = n_head
        assert hidden_dim % n_head == 0

        self.head_dim = hidden_dim // n_head
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.phoneme_vocab_size = phoneme_vocab_size
        self.p_dropout = p_dropout
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        self.EOS = EOS
        assert self.EOS == self.vocab_size - 1

        self.bert_proj = nn.Linear(1024, self.embedding_dim)
        self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
        self.ar_text_position = SinePositionalEmbedding(
            self.embedding_dim,
            dropout=0.1,
            scale=False,
            alpha=True,
            max_batch_size=max_batch_size,
            max_seq_len=max_seq_length,
        )
        self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size, self.p_dropout)
        self.ar_audio_position = SinePositionalEmbedding(
            self.embedding_dim,
            dropout=0.1,
            scale=False,
            alpha=True,
            max_batch_size=max_batch_size,
            max_seq_len=max_seq_length,
        )
        self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
        self.h: TransformerDecoderABC = TransformerDecoder(
            hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size
        )

        self.kv_class = KVCacheNHD
        self._register_load_state_dict_pre_hook(self.load_hook)

    def embed(
        self,
        x: List[torch.Tensor],
        y: torch.Tensor,
        bert_features: List[torch.Tensor],
    ):
        x_nested = torch.nested.nested_tensor(x)
        assert x_nested.size(0) <= self.max_batch_size
        bert_features_nested = torch.nested.nested_tensor(list(map(lambda x: x.transpose(0, 1), bert_features)))

        x_emb = self.ar_text_embedding.forward(x_nested)
        bert = self.bert_proj.forward(bert_features_nested)
        x_emb = x_emb + bert
        x_pos = self.ar_text_position.prefill(x_emb)

        y_nested = torch.nested.nested_tensor(list(y.unbind(0)))
        y_emb = self.ar_audio_embedding.forward(y_nested)
        y_pos = self.ar_audio_position.prefill(y_emb)

        xy_pos = torch.nested.nested_tensor([torch.cat([x_pos[i], y_pos[i]]) for i in range(len(x))])
        return xy_pos

    def post_forward(self, idx: int, session: T2SSession) -> None:
        pass

    def pre_forward(self, session: T2SSession) -> Tuple[List, Dict]:
        return list(), dict()


class CUDAGraphCache(CUDAGraphCacheABC):
    def __init__(
        self,
        decoder: T2SDecoderABC,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ) -> None:
        super().__init__(decoder, device, dtype)

    def release_graph(self, session: T2SSession):
        if session.id != self.id:
            self.assigned = False
        else:
            del session.graph, session.xy_pos_, session.xy_dec_, session.input_pos, session.kv_cache

    def get_cache_graph(self, session: T2SSession):
        assert self.graph
        session.graph = self.graph

        session.xy_pos_ = self.xy_pos
        session.xy_dec_ = self.xy_dec
        session.input_pos = self.input_pos.copy_(session.input_pos)

        for cache, cache_ in zip(self.kv_cache, session.kv_cache):
            cache.sync_cache(cache_)

    def capture_new_graph(self, session: T2SSession):
        session.xy_pos_ = self.xy_pos.clone()
        session.xy_dec_ = self.xy_dec.clone()
        session.input_pos = self.input_pos.clone().copy_(session.input_pos)

        args, kwds = self.decoder.pre_forward(session)
        graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, kv_caches=self.kv_cache, *args, **kwds)
        session.graph = graph


class CUDAGraphRunner:
    def __init__(
        self,
        decoder_model: T2SDecoderABC,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ) -> None:
        assert device.type == "cuda"
        self.device = device
        self.dtype = dtype

        self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)

        self.graphcache = CUDAGraphCache(decoder_model, device, dtype)

    def _handle_request(self, request: T2SRequest):
        with self.device:
            decoder = self.decoder_model
            session = T2SSession(decoder, request, device=self.device, dtype=self.dtype)

            t1 = 0.0
            infer_speed = 0.0

            torch_profiler = TorchProfiler(request.debug)
            with torch_profiler.profiler():
                for idx in tqdm(range(1500)):
                    if idx == 0:
                        session.kv_cache = decoder.init_cache(session.bsz)
                        xy_dec = decoder.h.prefill(session.xy_pos, session.attn_mask_nested, session.kv_cache)
                        xy_dec = torch.stack([t[[-1]] for t in xy_dec.unbind()])
                    else:
                        if request.use_cuda_graph and session.graph is None and torch.cuda.is_available():
                            self.graphcache.assign_graph(session)

                        with torch_profiler.record("AR"):
                            if session.graph:
                                session.xy_pos_.copy_(session.xy_pos)
                                session.graph.replay()
                                xy_dec = session.xy_dec_.clone()
                            else:
                                args, kwds = decoder.pre_forward(session)
                                xy_dec = decoder.h.forward(
                                    session.input_pos,
                                    session.xy_pos,
                                    session.kv_cache,
                                    *args,
                                    **kwds,
                                )

                    decoder.post_forward(idx, session)
                    logits = decoder.ar_predict_layer(xy_dec[:, -1])
                    session.input_pos.add_(1)

                    if idx == 0:
                        logits[:, -1] = float("-inf")

                    with torch_profiler.record("Sampling"):
                        samples = session.sampler.sample(
                            logits=logits,
                            previous_tokens=session.y,
                            top_k=request.top_k,
                            top_p=request.top_p,
                            repetition_penalty=request.repetition_penalty,
                            temperature=request.temperature,
                        )

                        session.y = torch.cat([session.y, samples], dim=1)

                    with torch_profiler.record("EOS"):
                        argmax_token = torch.argmax(logits, dim=-1)
                        sample_token = samples.squeeze(1)
                        EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS)

                        newly_done_mask = EOS_mask & (~session.completed)
                        newly_done_indices = newly_done_mask.nonzero()

                        if newly_done_indices.numel() > 0:
                            session.y_results[newly_done_indices[0]] = session.y[
                                newly_done_indices[0], session.y_len : -1
                            ].squeeze(0)
                            session.completed[newly_done_indices] = True

                    if torch.all(session.completed).item():
                        if session.y.size(1) == 0:
                            session.y = torch.cat([session.y, torch.zeros_like(samples)], dim=1)
                            tqdm.write("Bad Zero Prediction")
                        else:
                            tqdm.write(
                                f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> \n{[i.size(0) for i in session.y_results].__str__().strip('[]')}"
                            )
                            tqdm.write(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
                            infer_speed = (idx - 1) / (time.perf_counter() - t1)
                        break

                    if (
                        request.early_stop_num != -1
                        and (session.y.size(1) - session.y_len) > request.early_stop_num
                    ) or idx == 1499:
                        for i in range(session.bsz):
                            if not session.completed[i].item():
                                session.y_results[i] = session.y[i, session.y_len :]
                                session.completed[i] = True
                        break

                    with torch_profiler.record("NextPos"):
                        y_emb = decoder.ar_audio_embedding(session.y[:, -1:])
                        session.xy_pos = decoder.ar_audio_position.forward(session.input_pos - session.x_lens, y_emb)

                    if idx == 2:
                        torch_profiler.start()
                        t1 = time.perf_counter()

                    if idx == 51:
                        torch_profiler.end()

                    if idx % 100 == 0:
                        match session.device.type:
                            case "cuda":
                                torch.cuda.empty_cache()
                            case "mps":
                                torch.mps.empty_cache()
                            case "xpu":
                                torch.xpu.empty_cache()
                            case "mtia":
                                torch.mtia.empty_cache()

            match session.device.type:
                case "cuda":
                    torch.cuda.empty_cache()
                case "mps":
                    torch.mps.empty_cache()
                case "xpu":
                    torch.xpu.empty_cache()
                case "mtia":
                    torch.mtia.empty_cache()
                case "cpu":
                    gc.collect()

            torch_profiler.end()
            self.graphcache.release_graph(session)
            return session.y_results[: request.valid_length], infer_speed

    def generate(self, request: T2SRequest):
        try:
            result, infer_speed = self._handle_request(request)
            t2s_result = T2SResult(result=result, infer_speed=infer_speed, status="Success")
        except Exception as e:
            t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
        return t2s_result

    @staticmethod
    def load_decoder(weights_path: os.PathLike, implement: str = "flash_attn"):
        print(f"Loading Text2Semantic Weights from {weights_path} with {implement.replace('_', ' ').title()} Implement")
        module_path = f"AR.models.t2s_model_{implement.lower()}"
        cls_name = "T2SDecoder"
        mod = __import__(module_path, fromlist=[cls_name])
        decoder_cls: T2SDecoderABC = getattr(mod, cls_name)
        dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=False, mmap=True)
        config = dict_s1["config"]
        decoder: T2SDecoderABC = decoder_cls(config, max_batch_size=1)
        state_dict = dict_s1["weight"]
        decoder.load_state_dict(state_dict)
        return decoder.eval()
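A hedged end-to-end sketch of driving this decoder. "path/to/gpt_weights.ckpt" is a placeholder; load_decoder expects a checkpoint dict with "config" and "weight" keys, as read above. fp16 is used because flash_attn kernels do not accept fp32 inputs; the request shapes are the same dummies as in the structs.py sketch.

import torch

from AR.models.structs import T2SRequest
from AR.models.t2s_model_flash_attn import CUDAGraphRunner

decoder = CUDAGraphRunner.load_decoder("path/to/gpt_weights.ckpt")  # placeholder path
runner = CUDAGraphRunner(decoder, device=torch.device("cuda"), dtype=torch.float16)

# Dummy request with illustrative shapes (see the structs.py sketch above).
request = T2SRequest(
    x=[torch.randint(0, 732, (24,), device="cuda")],
    x_lens=torch.tensor([24], device="cuda"),
    prompts=torch.randint(0, 1024, (1, 16), device="cuda"),
    bert_feature=[torch.randn(1024, 24, device="cuda", dtype=torch.float16)],
    valid_length=1,
)
result = runner.generate(request)
if result.status == "Success":
    semantic_tokens = result.result  # List[Tensor]; result.infer_speed is tokens/s
else:
    raise RuntimeError(result.traceback)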
configs/s1.yaml
ADDED
@@ -0,0 +1,31 @@
train:
  seed: 1234
  epochs: 300
  batch_size: 8
  gradient_accumulation: 4
  save_every_n_epoch: 1
  precision: 16
  gradient_clip: 1.0
optimizer:
  lr: 0.01
  lr_init: 0.00001
  lr_end: 0.0001
  warmup_steps: 2000
  decay_steps: 40000
data:
  max_eval_sample: 8
  max_sec: 54
  num_workers: 1
  pad_val: 1024 # same as EOS in model
model:
  vocab_size: 1025
  phoneme_vocab_size: 512
  embedding_dim: 512
  hidden_dim: 512
  head: 16
  linear_units: 2048
  n_layer: 12
  dropout: 0
  EOS: 1024
inference:
  top_k: 5
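These YAML files are consumed as plain dicts: T2SDecoder above reads the "model" block (hidden_dim, head, n_layer, vocab_size, phoneme_vocab_size, dropout, EOS). A minimal loading sketch, assuming PyYAML is available:

import yaml  # PyYAML, an assumed dependency

with open("configs/s1.yaml") as f:
    config = yaml.safe_load(f)

# Mirrors the assert in T2SDecoder: EOS must be the last vocabulary id.
assert config["model"]["EOS"] == config["model"]["vocab_size"] - 1
print(config["model"]["hidden_dim"], config["model"]["head"], config["model"]["n_layer"])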
configs/s1big.yaml
ADDED
@@ -0,0 +1,31 @@
train:
  seed: 1234
  epochs: 300
  batch_size: 8
  gradient_accumulation: 4
  save_every_n_epoch: 1
  precision: 16-mixed
  gradient_clip: 1.0
optimizer:
  lr: 0.01
  lr_init: 0.00001
  lr_end: 0.0001
  warmup_steps: 2000
  decay_steps: 40000
data:
  max_eval_sample: 8
  max_sec: 54
  num_workers: 1
  pad_val: 1024 # same as EOS in model
model:
  vocab_size: 1025
  phoneme_vocab_size: 512
  embedding_dim: 1024
  hidden_dim: 1024
  head: 16
  linear_units: 2048
  n_layer: 16
  dropout: 0
  EOS: 1024
inference:
  top_k: 5
configs/s1big2.yaml
ADDED
@@ -0,0 +1,31 @@
train:
  seed: 1234
  epochs: 300
  batch_size: 12
  gradient_accumulation: 4
  save_every_n_epoch: 1
  precision: 16-mixed
  gradient_clip: 1.0
optimizer:
  lr: 0.01
  lr_init: 0.00001
  lr_end: 0.0001
  warmup_steps: 2000
  decay_steps: 40000
data:
  max_eval_sample: 8
  max_sec: 54
  num_workers: 1
  pad_val: 1024 # same as EOS in model
model:
  vocab_size: 1025
  phoneme_vocab_size: 512
  embedding_dim: 1024
  hidden_dim: 1024
  head: 16
  linear_units: 2048
  n_layer: 6
  dropout: 0
  EOS: 1024
inference:
  top_k: 5
configs/s1longer-v2.yaml
ADDED
@@ -0,0 +1,31 @@
train:
  seed: 1234
  epochs: 20
  batch_size: 8
  save_every_n_epoch: 1
  precision: 16-mixed
  gradient_clip: 1.0
optimizer:
  lr: 0.01
  lr_init: 0.00001
  lr_end: 0.0001
  warmup_steps: 2000
  decay_steps: 40000
data:
  max_eval_sample: 8
  max_sec: 54
  num_workers: 4
  pad_val: 1024 # same as EOS in model
model:
  vocab_size: 1025
  phoneme_vocab_size: 732
  embedding_dim: 512
  hidden_dim: 512
  head: 16
  linear_units: 2048
  n_layer: 24
  dropout: 0
  EOS: 1024
  random_bert: 0
inference:
  top_k: 15
configs/s1longer.yaml
ADDED
@@ -0,0 +1,31 @@
train:
  seed: 1234
  epochs: 20
  batch_size: 8
  save_every_n_epoch: 1
  precision: 16-mixed
  gradient_clip: 1.0
optimizer:
  lr: 0.01
  lr_init: 0.00001
  lr_end: 0.0001
  warmup_steps: 2000
  decay_steps: 40000
data:
  max_eval_sample: 8
  max_sec: 54
  num_workers: 4
  pad_val: 1024 # same as EOS in model
model:
  vocab_size: 1025
  phoneme_vocab_size: 512
  embedding_dim: 512
  hidden_dim: 512
  head: 16
  linear_units: 2048
  n_layer: 24
  dropout: 0
  EOS: 1024
  random_bert: 0
inference:
  top_k: 5
configs/s1mq.yaml
ADDED
@@ -0,0 +1,77 @@
train:
  seed: 1234
  epochs: 100
  batch_size: 6
  gradient_accumulation: 4
  save_every_n_epoch: 1
  precision: 32
  gradient_clip: 1.0
optimizer:
  lr: 0.01
  lr_init: 0.00001
  lr_end: 0.0001
  warmup_steps: 2000
  decay_steps: 40000
data:
  max_eval_sample: 8
  max_sec: 40
  num_workers: 1
  pad_val: 1024 # same with EOS in model
model:
  saving_path: "ckpt/"
  resume_checkpoint: null
  vocoder_config_path: "quantizer/new_ckpt/config.json"
  vocoder_ckpt_path: "quantizer/new_ckpt/g_00600000"
  datadir: "/home/liweiche/GigaSpeech/wavs"
  metapath: "/home/liweiche/GigaSpeech/train2.json"
  val_metapath: "/home/liweiche/GigaSpeech/dev2.json"
  sampledir: "logs/"
  pretrained_path: null
  lr: 0.0001
  batch_size: 200.0
  train_bucket_size: 8192
  training_step: 800000
  optim_flat_percent: 0.0
  warmup_step: 50
  adam_beta1: 0.9
  adam_beta2: 0.98
  ffd_size: 3072
  hidden_size: 768
  enc_nlayers: 6
  dec_nlayers: 6
  nheads: 12
  ar_layer: 4
  ar_ffd_size: 1024
  ar_hidden_size: 256
  ar_nheads: 4
  aligner_softmax_temp: 1.0
  layer_norm_eps: 0.00001
  speaker_embed_dropout: 0.05
  label_smoothing: 0.0
  val_check_interval: 5000
  check_val_every_n_epoch: 1
  precision: "fp16"
  nworkers: 16
  distributed: true
  accelerator: "ddp"
  version: null
  accumulate_grad_batches: 1
  use_repetition_token: true
  use_repetition_gating: false
  repetition_penalty: 1.0
  sampling_temperature: 1.0
  top_k: -1
  min_top_k: 3
  top_p: 0.8
  sample_num: 4
  length_penalty_max_length: 15000
  length_penalty_max_prob: 0.95
  max_input_length: 2048
  max_output_length: 2000
  sample_rate: 16000
  n_codes: 1024
  n_cluster_groups: 1
  phone_context_window: 4
  phoneset_size: 1000
inference:
  top_k: 5
configs/s2.json
ADDED
@@ -0,0 +1,90 @@
{
  "train": {
    "log_interval": 100,
    "eval_interval": 500,
    "seed": 1234,
    "epochs": 100,
    "learning_rate": 0.0001,
    "betas": [
      0.8,
      0.99
    ],
    "eps": 1e-09,
    "batch_size": 32,
    "fp16_run": true,
    "lr_decay": 0.999875,
    "segment_size": 20480,
    "init_lr_ratio": 1,
    "warmup_epochs": 0,
    "c_mel": 45,
    "c_kl": 1.0,
    "text_low_lr_rate": 0.4
  },
  "data": {
    "max_wav_value": 32768.0,
    "sampling_rate": 32000,
    "filter_length": 2048,
    "hop_length": 640,
    "win_length": 2048,
    "n_mel_channels": 128,
    "mel_fmin": 0.0,
    "mel_fmax": null,
    "add_blank": true,
    "n_speakers": 300,
    "cleaned_text": true
  },
  "model": {
    "inter_channels": 192,
    "hidden_channels": 192,
    "filter_channels": 768,
    "n_heads": 2,
    "n_layers": 6,
    "kernel_size": 3,
    "p_dropout": 0.1,
    "resblock": "1",
    "resblock_kernel_sizes": [
      3,
      7,
      11
    ],
    "resblock_dilation_sizes": [
      [
        1,
        3,
        5
      ],
      [
        1,
        3,
        5
      ],
      [
        1,
        3,
        5
      ]
    ],
    "upsample_rates": [
      10,
      8,
      2,
      2,
      2
    ],
    "upsample_initial_channel": 512,
    "upsample_kernel_sizes": [
      16,
      16,
      8,
      2,
      2
    ],
    "n_layers_q": 3,
    "use_spectral_norm": false,
    "gin_channels": 512,
    "semantic_frame_rate": "25hz",
    "freeze_quantizer": true
  },
  "s2_ckpt_dir": "logs/s2/big2k1",
  "content_module": "cnhubert"
}
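
A quick sanity check on the numbers in this config: the product of the vocoder's upsample_rates must equal hop_length, and sampling_rate / hop_length gives the spectrogram frame rate that the "25hz" semantic tokens are aligned against. A small sketch, assuming the file is saved at configs/s2.json:

import json
import math

with open("configs/s2.json") as f:
    hps = json.load(f)

sr = hps["data"]["sampling_rate"]   # 32000
hop = hps["data"]["hop_length"]     # 640
assert math.prod(hps["model"]["upsample_rates"]) == hop  # 10*8*2*2*2 == 640
print(sr / hop)  # 50.0 frames per second, i.e. 2x the "25hz" semantic rate
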
configs/train.yaml
ADDED
@@ -0,0 +1,32 @@
gpu:
  n_card: 1
  n_process_per_card: 2
io:
  text_path: D:\RVC1006\GPT-SoVITS\GPT_SoVITS
  save_every_n_epoch: 1
  precision: 16-mixed
  gradient_clip: 1.0
optimizer:
  lr: 0.01
  lr_init: 0.00001
  lr_end: 0.0001
  warmup_steps: 2000
  decay_steps: 40000
data:
  max_eval_sample: 8
  max_sec: 54
  num_workers: 1
  pad_val: 1024 # same with EOS in model
model:
  vocab_size: 1025
  phoneme_vocab_size: 512
  embedding_dim: 512
  hidden_dim: 512
  head: 16
  linear_units: 2048
  n_layer: 24
  dropout: 0
  EOS: 1024
  random_bert: 0
inference:
  top_k: 5
download.py
ADDED
@@ -0,0 +1,5 @@
import os, sys
now_dir = os.getcwd()
sys.path.insert(0, now_dir)
from text.g2pw import G2PWPinyin  # absolute import: this file runs as a top-level script, so a relative import (.text.g2pw) would fail
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel", model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", v_to_u=False, neutral_tone_with_five=True)
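
The module's only effect is to instantiate the G2PW front-end once, which forces its model files to be fetched and loaded. If G2PWPinyin follows the pypinyin interface it is modeled on, downstream usage would look like the sketch below; lazy_pinyin and the Style argument are assumptions borrowed from pypinyin, not confirmed by this diff:

from pypinyin import Style

pinyin = g2pw.lazy_pinyin("你好", neutral_tone_with_five=True, style=Style.TONE3)  # hypothetical call
print(pinyin)
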
eres2net/ERes2Net.py
ADDED
@@ -0,0 +1,260 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""
Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
"""

import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import pooling_layers as pooling_layers
from fusion import AFF


class ReLU(nn.Hardtanh):
    """ReLU clipped at 20 (a bounded ReLU)."""

    def __init__(self, inplace=False):
        super(ReLU, self).__init__(0, 20, inplace)

    def __repr__(self):
        inplace_str = 'inplace' if self.inplace else ''
        return self.__class__.__name__ + ' (' + inplace_str + ')'


class BasicBlockERes2Net(nn.Module):
    expansion = 2

    def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
        super(BasicBlockERes2Net, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale

        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.relu = ReLU(inplace=True)

        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes))
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class BasicBlockERes2Net_diff_AFF(nn.Module):
    expansion = 2

    def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
        super(BasicBlockERes2Net_diff_AFF, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale

        convs = []
        fuse_models = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        for j in range(self.nums - 1):
            fuse_models.append(AFF(channels=width))

        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.fuse_models = nn.ModuleList(fuse_models)
        self.relu = ReLU(inplace=True)

        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes))
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = self.fuse_models[i - 1](sp, spx[i])
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class ERes2Net(nn.Module):
    def __init__(self,
                 block=BasicBlockERes2Net,
                 block_fuse=BasicBlockERes2Net_diff_AFF,
                 num_blocks=[3, 4, 6, 3],
                 m_channels=32,
                 feat_dim=80,
                 embedding_size=192,
                 pooling_func='TSTP',
                 two_emb_layer=False):
        super(ERes2Net, self).__init__()
        self.in_planes = m_channels
        self.feat_dim = feat_dim
        self.embedding_size = embedding_size
        self.stats_dim = int(feat_dim / 8) * m_channels * 8
        self.two_emb_layer = two_emb_layer

        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)
        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)

        # Downsampling module for each layer
        self.layer1_downsample = nn.Conv2d(m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, bias=False)
        self.layer2_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False)
        self.layer3_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False)

        # Bottom-up fusion module
        self.fuse_mode12 = AFF(channels=m_channels * 4)
        self.fuse_mode123 = AFF(channels=m_channels * 8)
        self.fuse_mode1234 = AFF(channels=m_channels * 16)

        self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
        self.pool = getattr(pooling_layers, pooling_func)(
            in_dim=self.stats_dim * block.expansion)
        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
                               embedding_size)
        if self.two_emb_layer:
            self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
            self.seg_2 = nn.Linear(embedding_size, embedding_size)
        else:
            self.seg_bn_1 = nn.Identity()
            self.seg_2 = nn.Identity()

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out1_downsample = self.layer1_downsample(out1)
        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
        out3 = self.layer3(out2)
        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
        out4 = self.layer4(out3)
        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
        stats = self.pool(fuse_out1234)

        embed_a = self.seg_1(stats)
        if self.two_emb_layer:
            out = F.relu(embed_a)
            out = self.seg_bn_1(out)
            embed_b = self.seg_2(out)
            return embed_b
        else:
            return embed_a

    def forward3(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out1_downsample = self.layer1_downsample(out1)
        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
        out3 = self.layer3(out2)
        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
        out4 = self.layer4(out3)
        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2).mean(-1)
        return fuse_out1234


if __name__ == '__main__':
    x = torch.zeros(10, 300, 80)
    model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func='TSTP')
    model.eval()
    out = model(x)
    print(out.shape)  # torch.Size([10, 192])

    num_params = sum(param.numel() for param in model.parameters())
    print("{} M".format(num_params / 1e6))  # 6.61M
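
The __main__ block above runs on zero tensors; a hedged sketch of real usage feeds 80-dim Kaldi-style fbank features (which eres2net/kaldi.py below reimplements, and which torchaudio.compliance.kaldi.fbank also provides) and reads out a 192-dim speaker embedding. speaker.wav is a hypothetical input file:

import torch
import torchaudio

wav, sr = torchaudio.load("speaker.wav")  # (channels, num_samples)
feats = torchaudio.compliance.kaldi.fbank(wav, num_mel_bins=80, sample_frequency=sr)  # (T, 80)

model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func='TSTP')
model.eval()
with torch.no_grad():
    emb = model(feats.unsqueeze(0))  # (1, 192) speaker embedding
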
eres2net/ERes2NetV2.py
ADDED
@@ -0,0 +1,292 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""
To further improve the short-duration feature extraction capability of ERes2Net, we expand the channel dimension
within each stage. However, this modification also increases the number of model parameters and computational complexity.
To alleviate this problem, we propose an improved ERes2NetV2 by pruning redundant structures, ultimately reducing
both the model parameters and its computational cost.
"""

import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import pooling_layers as pooling_layers
from fusion import AFF


class ReLU(nn.Hardtanh):
    """ReLU clipped at 20 (a bounded ReLU)."""

    def __init__(self, inplace=False):
        super(ReLU, self).__init__(0, 20, inplace)

    def __repr__(self):
        inplace_str = 'inplace' if self.inplace else ''
        return self.__class__.__name__ + ' (' + inplace_str + ')'


class BasicBlockERes2NetV2(nn.Module):

    def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
        super(BasicBlockERes2NetV2, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale
        self.expansion = expansion

        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.relu = ReLU(inplace=True)

        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes,
                          self.expansion * planes,
                          kernel_size=1,
                          stride=stride,
                          bias=False),
                nn.BatchNorm2d(self.expansion * planes))
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class BasicBlockERes2NetV2AFF(nn.Module):

    def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
        super(BasicBlockERes2NetV2AFF, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale
        self.expansion = expansion

        convs = []
        fuse_models = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        for j in range(self.nums - 1):
            fuse_models.append(AFF(channels=width, r=4))

        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.fuse_models = nn.ModuleList(fuse_models)
        self.relu = ReLU(inplace=True)

        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes,
                          self.expansion * planes,
                          kernel_size=1,
                          stride=stride,
                          bias=False),
                nn.BatchNorm2d(self.expansion * planes))
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = self.fuse_models[i - 1](sp, spx[i])
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class ERes2NetV2(nn.Module):
    def __init__(self,
                 block=BasicBlockERes2NetV2,
                 block_fuse=BasicBlockERes2NetV2AFF,
                 num_blocks=[3, 4, 6, 3],
                 m_channels=64,
                 feat_dim=80,
                 embedding_size=192,
                 baseWidth=26,
                 scale=2,
                 expansion=2,
                 pooling_func='TSTP',
                 two_emb_layer=False):
        super(ERes2NetV2, self).__init__()
        self.in_planes = m_channels
        self.feat_dim = feat_dim
        self.embedding_size = embedding_size
        self.stats_dim = int(feat_dim / 8) * m_channels * 8
        self.two_emb_layer = two_emb_layer
        self.baseWidth = baseWidth
        self.scale = scale
        self.expansion = expansion

        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)
        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)

        # Downsampling module
        self.layer3_ds = nn.Conv2d(m_channels * 4 * self.expansion, m_channels * 8 * self.expansion, kernel_size=3,
                                   padding=1, stride=2, bias=False)

        # Bottom-up fusion module
        self.fuse34 = AFF(channels=m_channels * 8 * self.expansion, r=4)

        self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
        self.pool = getattr(pooling_layers, pooling_func)(
            in_dim=self.stats_dim * self.expansion)
        self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats,
                               embedding_size)
        if self.two_emb_layer:
            self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
            self.seg_2 = nn.Linear(embedding_size, embedding_size)
        else:
            self.seg_bn_1 = nn.Identity()
            self.seg_2 = nn.Identity()

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion))
            self.in_planes = planes * self.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)
        out3_ds = self.layer3_ds(out3)
        fuse_out34 = self.fuse34(out4, out3_ds)
        stats = self.pool(fuse_out34)

        embed_a = self.seg_1(stats)
        if self.two_emb_layer:
            out = F.relu(embed_a)
            out = self.seg_bn_1(out)
            embed_b = self.seg_2(out)
            return embed_b
        else:
            return embed_a

    def forward3(self, x):
        # Pooled backbone features without the embedding head: flatten the
        # fused (C, H) axes, then average over time.
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)
        out3_ds = self.layer3_ds(out3)
        fuse_out34 = self.fuse34(out4, out3_ds)  # e.g. torch.Size([16, 2048, 10, 72])
        return fuse_out34.flatten(start_dim=1, end_dim=2).mean(-1)


if __name__ == '__main__':
    from thop import profile  # `profile` was otherwise undefined; thop provides this MACs/params counter

    x = torch.randn(1, 300, 80)
    model = ERes2NetV2(feat_dim=80, embedding_size=192, m_channels=64, baseWidth=26, scale=2, expansion=2)
    model.eval()
    y = model(x)
    print(y.size())
    macs, num_params = profile(model, inputs=(x, ))
    print("Params: {} M".format(num_params / 1e6))  # 17.86 M
    print("MACs: {} G".format(macs / 1e9))  # 12.69 G
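
forward3 above skips the pooling and embedding head entirely: it flattens the fused channel/frequency axes and averages over time, returning one long vector per utterance. A small sketch of that path; the exact output width depends on m_channels and expansion, so it is read from the tensor rather than assumed:

import torch

model = ERes2NetV2(feat_dim=80, embedding_size=192)
model.eval()
feats = torch.randn(2, 300, 80)  # (B, T, 80) fbank features
with torch.no_grad():
    pooled = model.forward3(feats)  # (B, C*H): channel/freq axes flattened, mean over time
print(pooled.shape)
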
eres2net/ERes2Net_huge.py
ADDED
@@ -0,0 +1,286 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

""" Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
ERes2Net-huge is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance.
"""

import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import pooling_layers as pooling_layers
from fusion import AFF


class ReLU(nn.Hardtanh):
    """ReLU clipped at 20 (a bounded ReLU)."""

    def __init__(self, inplace=False):
        super(ReLU, self).__init__(0, 20, inplace)

    def __repr__(self):
        inplace_str = 'inplace' if self.inplace else ''
        return self.__class__.__name__ + ' (' + inplace_str + ')'


class BasicBlockERes2Net(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
        super(BasicBlockERes2Net, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale

        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.relu = ReLU(inplace=True)

        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes))
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class BasicBlockERes2Net_diff_AFF(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
        super(BasicBlockERes2Net_diff_AFF, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale

        convs = []
        fuse_models = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        for j in range(self.nums - 1):
            fuse_models.append(AFF(channels=width))

        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.fuse_models = nn.ModuleList(fuse_models)
        self.relu = ReLU(inplace=True)

        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes))
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = self.fuse_models[i - 1](sp, spx[i])
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class ERes2Net(nn.Module):
    def __init__(self,
                 block=BasicBlockERes2Net,
                 block_fuse=BasicBlockERes2Net_diff_AFF,
                 num_blocks=[3, 4, 6, 3],
                 m_channels=64,
                 feat_dim=80,
                 embedding_size=192,
                 pooling_func='TSTP',
                 two_emb_layer=False):
        super(ERes2Net, self).__init__()
        self.in_planes = m_channels
        self.feat_dim = feat_dim
        self.embedding_size = embedding_size
        self.stats_dim = int(feat_dim / 8) * m_channels * 8
        self.two_emb_layer = two_emb_layer

        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)

        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)

        self.layer1_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False)
        self.layer2_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False)
        self.layer3_downsample = nn.Conv2d(m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, bias=False)

        self.fuse_mode12 = AFF(channels=m_channels * 8)
        self.fuse_mode123 = AFF(channels=m_channels * 16)
        self.fuse_mode1234 = AFF(channels=m_channels * 32)

        self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
        self.pool = getattr(pooling_layers, pooling_func)(
            in_dim=self.stats_dim * block.expansion)
        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
        if self.two_emb_layer:
            self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
            self.seg_2 = nn.Linear(embedding_size, embedding_size)
        else:
            self.seg_bn_1 = nn.Identity()
            self.seg_2 = nn.Identity()

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)

        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out1_downsample = self.layer1_downsample(out1)
        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
        out3 = self.layer3(out2)
        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
        out4 = self.layer4(out3)
        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
        stats = self.pool(fuse_out1234)

        embed_a = self.seg_1(stats)
        if self.two_emb_layer:
            out = F.relu(embed_a)
            out = self.seg_bn_1(out)
            embed_b = self.seg_2(out)
            return embed_b
        else:
            return embed_a

    def forward2(self, x, if_mean):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)

        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out1_downsample = self.layer1_downsample(out1)
        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
        out3 = self.layer3(out2)
        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
        out4 = self.layer4(out3)
        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2)  # bs,20480,T
        if if_mean == False:
            mean = fuse_out1234[0].transpose(1, 0)  # (T,20480), bs=T
        else:
            mean = fuse_out1234.mean(2)  # bs,20480
        mean_std = torch.cat([mean, torch.zeros_like(mean)], 1)
        return self.seg_1(mean_std)  # (T,192)

    def forward3(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)

        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out1_downsample = self.layer1_downsample(out1)
        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
        out3 = self.layer3(out2)
        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
        out4 = self.layer4(out3)
        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2).mean(-1)
        return fuse_out1234
eres2net/fusion.py
ADDED
@@ -0,0 +1,29 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import torch
import torch.nn as nn


class AFF(nn.Module):

    def __init__(self, channels=64, r=4):
        super(AFF, self).__init__()
        inter_channels = int(channels // r)

        self.local_att = nn.Sequential(
            nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x, ds_y):
        xa = torch.cat((x, ds_y), dim=1)
        x_att = self.local_att(xa)
        x_att = 1.0 + torch.tanh(x_att)
        xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0 - x_att)

        return xo
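
Since x_att = 1 + tanh(.) lies in (0, 2), AFF computes a soft element-wise mix x * x_att + ds_y * (2 - x_att) whose two weights always sum to 2; when x_att is near 1 it degenerates to a plain sum of the branches. A toy check on two same-shape feature maps:

import torch

aff = AFF(channels=64, r=4)
x = torch.randn(2, 64, 10, 25)     # deeper-stage feature map
ds_y = torch.randn(2, 64, 10, 25)  # shallower map after its stride-2 downsampling conv, same shape
out = aff(x, ds_y)                 # (2, 64, 10, 25)
print(out.shape)
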
eres2net/kaldi.py
ADDED
|
@@ -0,0 +1,819 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torchaudio
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
"get_mel_banks",
|
| 10 |
+
"inverse_mel_scale",
|
| 11 |
+
"inverse_mel_scale_scalar",
|
| 12 |
+
"mel_scale",
|
| 13 |
+
"mel_scale_scalar",
|
| 14 |
+
"spectrogram",
|
| 15 |
+
"fbank",
|
| 16 |
+
"mfcc",
|
| 17 |
+
"vtln_warp_freq",
|
| 18 |
+
"vtln_warp_mel_freq",
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
# numeric_limits<float>::epsilon() 1.1920928955078125e-07
|
| 22 |
+
EPSILON = torch.tensor(torch.finfo(torch.float).eps)
|
| 23 |
+
# 1 milliseconds = 0.001 seconds
|
| 24 |
+
MILLISECONDS_TO_SECONDS = 0.001
|
| 25 |
+
|
| 26 |
+
# window types
|
| 27 |
+
HAMMING = "hamming"
|
| 28 |
+
HANNING = "hanning"
|
| 29 |
+
POVEY = "povey"
|
| 30 |
+
RECTANGULAR = "rectangular"
|
| 31 |
+
BLACKMAN = "blackman"
|
| 32 |
+
WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _get_epsilon(device, dtype):
|
| 36 |
+
return EPSILON.to(device=device, dtype=dtype)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _next_power_of_2(x: int) -> int:
|
| 40 |
+
r"""Returns the smallest power of 2 that is greater than x"""
|
| 41 |
+
return 1 if x == 0 else 2 ** (x - 1).bit_length()
|
| 42 |
+
|
| 43 |
+
|
def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
    r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``)
    representing how the window is shifted along the waveform. Each row is a frame.

    Args:
        waveform (Tensor): Tensor of size ``num_samples``
        window_size (int): Frame length
        window_shift (int): Frame shift
        snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends.

    Returns:
        Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame
    """
    assert waveform.dim() == 1
    num_samples = waveform.size(0)
    strides = (window_shift * waveform.stride(0), waveform.stride(0))

    if snip_edges:
        if num_samples < window_size:
            return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
        else:
            m = 1 + (num_samples - window_size) // window_shift
    else:
        reversed_waveform = torch.flip(waveform, [0])
        m = (num_samples + (window_shift // 2)) // window_shift
        pad = window_size // 2 - window_shift // 2
        pad_right = reversed_waveform
        if pad > 0:
            # torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
            # but we want [2, 1, 0, 0, 1, 2]
            pad_left = reversed_waveform[-pad:]
            waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
        else:
            # pad is negative so we want to trim the waveform at the front
            waveform = torch.cat((waveform[-pad:], pad_right), dim=0)

    sizes = (m, window_size)
    return waveform.as_strided(sizes, strides)

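
# Usage sketch (illustration only, not part of the original file; the helper
# name is hypothetical). It shows how _get_strided frames a toy waveform: with
# 10 samples, window_size=4 and window_shift=2, snip_edges=True yields
# m = 1 + (10 - 4) // 2 = 4 overlapping frames.
def _demo_get_strided():
    wave = torch.arange(10.0)
    frames = _get_strided(wave, window_size=4, window_shift=2, snip_edges=True)
    assert frames.shape == (4, 4)
    assert frames[1].tolist() == [2.0, 3.0, 4.0, 5.0]
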
def _feature_window_function(
    window_type: str,
    window_size: int,
    blackman_coeff: float,
    device: torch.device,
    dtype: torch.dtype,
) -> Tensor:
    r"""Returns a window function with the given type and size"""
    if window_type == HANNING:
        return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
    elif window_type == HAMMING:
        return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype)
    elif window_type == POVEY:
        # like hanning but goes to zero at edges
        return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85)
    elif window_type == RECTANGULAR:
        return torch.ones(window_size, device=device, dtype=dtype)
    elif window_type == BLACKMAN:
        a = 2 * math.pi / (window_size - 1)
        window_function = torch.arange(window_size, device=device, dtype=dtype)
        # can't use torch.blackman_window as they use different coefficients
        return (
            blackman_coeff
            - 0.5 * torch.cos(a * window_function)
            + (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)
        ).to(device=device, dtype=dtype)
    else:
        raise Exception("Invalid window type " + window_type)


def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
    r"""Returns the log energy of size (m) for a strided_input (m,*)"""
    device, dtype = strided_input.device, strided_input.dtype
    log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log()  # size (m)
    if energy_floor == 0.0:
        return log_energy
    return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))


def _get_waveform_and_window_properties(
    waveform: Tensor,
    channel: int,
    sample_frequency: float,
    frame_shift: float,
    frame_length: float,
    round_to_power_of_two: bool,
    preemphasis_coefficient: float,
) -> Tuple[Tensor, int, int, int]:
    r"""Gets the waveform and window properties"""
    channel = max(channel, 0)
    assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0))
    waveform = waveform[channel, :]  # size (n)
    window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
    window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
    padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size

    assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format(
        window_size, len(waveform)
    )
    assert 0 < window_shift, "`window_shift` must be greater than 0"
    assert padded_window_size % 2 == 0, (
        "the padded `window_size` must be divisible by two; use `round_to_power_of_two` or change `frame_length`"
    )
    assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
    assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
    return waveform, window_shift, window_size, padded_window_size


def _get_window(
    waveform: Tensor,
    padded_window_size: int,
    window_size: int,
    window_shift: int,
    window_type: str,
    blackman_coeff: float,
    snip_edges: bool,
    raw_energy: bool,
    energy_floor: float,
    dither: float,
    remove_dc_offset: bool,
    preemphasis_coefficient: float,
) -> Tuple[Tensor, Tensor]:
    r"""Gets a window and its log energy

    Returns:
        (Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m)
    """
    device, dtype = waveform.device, waveform.dtype
    epsilon = _get_epsilon(device, dtype)

    # size (m, window_size)
    strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)

    if dither != 0.0:
        rand_gauss = torch.randn(strided_input.shape, device=device, dtype=dtype)
        strided_input = strided_input + rand_gauss * dither

    if remove_dc_offset:
        # Subtract each row/frame by its mean
        row_means = torch.mean(strided_input, dim=1).unsqueeze(1)  # size (m, 1)
        strided_input = strided_input - row_means

    if raw_energy:
        # Compute the log energy of each row/frame before applying preemphasis and
        # window function
        signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor)  # size (m)

    if preemphasis_coefficient != 0.0:
        # strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
        offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(
            0
        )  # size (m, window_size + 1)
        strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1]

    # Apply window_function to each row/frame
    window_function = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype).unsqueeze(
        0
    )  # size (1, window_size)
    strided_input = strided_input * window_function  # size (m, window_size)

    # Pad columns with zero until we reach size (m, padded_window_size)
    if padded_window_size != window_size:
        padding_right = padded_window_size - window_size
        strided_input = torch.nn.functional.pad(
            strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0
        ).squeeze(0)

    # Compute energy after window function (not the raw one)
    if not raw_energy:
        signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor)  # size (m)

    return strided_input, signal_log_energy


def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
    # subtracts the column mean of the tensor size (m, n) if subtract_mean=True
    # it returns size (m, n)
    if subtract_mean:
        col_means = torch.mean(tensor, dim=0).unsqueeze(0)
        tensor = tensor - col_means
    return tensor


def spectrogram(
    waveform: Tensor,
    blackman_coeff: float = 0.42,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    min_duration: float = 0.0,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    window_type: str = POVEY,
) -> Tensor:
    r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's
    compute-spectrogram-feats.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
        blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
        channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
            the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
        energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
            this floor is applied to the zeroth component, representing the total signal energy. The floor on the
            individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
            it this way. (Default: ``False``)
        window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
            (Default: ``'povey'``)

    Returns:
        Tensor: A spectrogram identical to what Kaldi would output. The shape is
        (m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided
    """
    device, dtype = waveform.device, waveform.dtype
    epsilon = _get_epsilon(device, dtype)

    waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
        waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
    )

    if len(waveform) < min_duration * sample_frequency:
        # signal is too short
        return torch.empty(0)

    strided_input, signal_log_energy = _get_window(
        waveform,
        padded_window_size,
        window_size,
        window_shift,
        window_type,
        blackman_coeff,
        snip_edges,
        raw_energy,
        energy_floor,
        dither,
        remove_dc_offset,
        preemphasis_coefficient,
    )

    # size (m, padded_window_size // 2 + 1)
    fft = torch.fft.rfft(strided_input)

    # Convert the FFT into a power spectrum
    power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log()  # size (m, padded_window_size // 2 + 1)
    power_spectrum[:, 0] = signal_log_energy

    power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
    return power_spectrum

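
# Usage sketch (illustration only, not part of the original file; the helper
# name is hypothetical). A 25 ms window at 16 kHz is 400 samples, padded to
# 512 by round_to_power_of_two, so each frame carries 512 // 2 + 1 = 257 bins.
def _demo_spectrogram():
    wav = torch.randn(1, 32000)  # (channels, samples), 2 s of fake mono audio
    spec = spectrogram(wav, sample_frequency=16000.0, dither=0.0)
    assert spec.shape == (198, 257)  # 1 + (32000 - 400) // 160 = 198 frames
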
def inverse_mel_scale_scalar(mel_freq: float) -> float:
    return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)


def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
    return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)


def mel_scale_scalar(freq: float) -> float:
    return 1127.0 * math.log(1.0 + freq / 700.0)


def mel_scale(freq: Tensor) -> Tensor:
    return 1127.0 * (1.0 + freq / 700.0).log()


def vtln_warp_freq(
    vtln_low_cutoff: float,
    vtln_high_cutoff: float,
    low_freq: float,
    high_freq: float,
    vtln_warp_factor: float,
    freq: Tensor,
) -> Tensor:
    r"""This computes a VTLN warping function that is not the same as HTK's one,
    but has similar inputs (this function has the advantage of never producing
    empty bins).

    This function computes a warp function F(freq), defined between low_freq
    and high_freq inclusive, with the following properties:
        F(low_freq) == low_freq
        F(high_freq) == high_freq
    The function is continuous and piecewise linear with two inflection
    points.
    The lower inflection point (measured in terms of the unwarped
    frequency) is at frequency l, determined as described below.
    The higher inflection point is at a frequency h, determined as
    described below.
    If l <= f <= h, then F(f) = f/vtln_warp_factor.
    If the higher inflection point (measured in terms of the unwarped
    frequency) is at h, then max(h, F(h)) == vtln_high_cutoff.
    Since (by the last point) F(h) == h/vtln_warp_factor, then
    max(h, h/vtln_warp_factor) == vtln_high_cutoff, so
        h = vtln_high_cutoff / max(1, 1/vtln_warp_factor)
          = vtln_high_cutoff * min(1, vtln_warp_factor).
    If the lower inflection point (measured in terms of the unwarped
    frequency) is at l, then min(l, F(l)) == vtln_low_cutoff.
    This implies that
        l = vtln_low_cutoff / min(1, 1/vtln_warp_factor)
          = vtln_low_cutoff * max(1, vtln_warp_factor)

    Args:
        vtln_low_cutoff (float): Lower frequency cutoff for VTLN
        vtln_high_cutoff (float): Upper frequency cutoff for VTLN
        low_freq (float): Lower frequency cutoff in mel computation
        high_freq (float): Upper frequency cutoff in mel computation
        vtln_warp_factor (float): Vtln warp factor
        freq (Tensor): given frequency in Hz

    Returns:
        Tensor: Freq after vtln warp
    """
    assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq"
    assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]"
    l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
    h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
    scale = 1.0 / vtln_warp_factor
    Fl = scale * l  # F(l)
    Fh = scale * h  # F(h)
    assert l > low_freq and h < high_freq
    # slope of left part of the 3-piece linear function
    scale_left = (Fl - low_freq) / (l - low_freq)
    # [slope of center part is just "scale"]

    # slope of right part of the 3-piece linear function
    scale_right = (high_freq - Fh) / (high_freq - h)

    res = torch.empty_like(freq)

    outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(freq, high_freq)  # freq < low_freq || freq > high_freq
    before_l = torch.lt(freq, l)  # freq < l
    before_h = torch.lt(freq, h)  # freq < h
    after_h = torch.ge(freq, h)  # freq >= h

    # order of operations matters here (since there are overlapping frequency regions)
    res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
    res[before_h] = scale * freq[before_h]
    res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
    res[outside_low_high_freq] = freq[outside_low_high_freq]

    return res


def vtln_warp_mel_freq(
    vtln_low_cutoff: float,
    vtln_high_cutoff: float,
    low_freq,
    high_freq: float,
    vtln_warp_factor: float,
    mel_freq: Tensor,
) -> Tensor:
    r"""
    Args:
        vtln_low_cutoff (float): Lower frequency cutoff for VTLN
        vtln_high_cutoff (float): Upper frequency cutoff for VTLN
        low_freq (float): Lower frequency cutoff in mel computation
        high_freq (float): Upper frequency cutoff in mel computation
        vtln_warp_factor (float): Vtln warp factor
        mel_freq (Tensor): Given frequency in Mel

    Returns:
        Tensor: ``mel_freq`` after vtln warp
    """
    return mel_scale(
        vtln_warp_freq(
            vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq)
        )
    )


def get_mel_banks(
    num_bins: int,
    window_length_padded: int,
    sample_freq: float,
    low_freq: float,
    high_freq: float,
    vtln_low: float,
    vtln_high: float,
    vtln_warp_factor: float,
    device=None,
    dtype=None,
) -> Tensor:
    """
    Returns:
        Tensor: The melbank ``bins`` of size (``num_bins``, ``num_fft_bins``).
        (The ``center_freqs`` computation is commented out below, so only the
        bins are returned in this adaptation.)
    """
    assert num_bins > 3, "Must have at least 3 mel bins"
    assert window_length_padded % 2 == 0
    num_fft_bins = window_length_padded // 2
    nyquist = 0.5 * sample_freq

    if high_freq <= 0.0:
        high_freq += nyquist

    assert (
        (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)
    ), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)

    # fft-bin width [think of it as Nyquist-freq / half-window-length]
    fft_bin_width = sample_freq / window_length_padded
    mel_low_freq = mel_scale_scalar(low_freq)
    mel_high_freq = mel_scale_scalar(high_freq)

    # divide by num_bins+1 in next line because of end-effects where the bins
    # spread out to the sides.
    mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)

    if vtln_high < 0.0:
        vtln_high += nyquist

    assert vtln_warp_factor == 1.0 or (
        (low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
    ), "Bad values in options: vtln-low {} and vtln-high {}, versus low-freq {} and high-freq {}".format(
        vtln_low, vtln_high, low_freq, high_freq
    )

    bin = torch.arange(num_bins).unsqueeze(1)
    left_mel = mel_low_freq + bin * mel_freq_delta  # size(num_bins, 1)
    center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta  # size(num_bins, 1)
    right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta  # size(num_bins, 1)

    if vtln_warp_factor != 1.0:
        left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel)
        center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel)
        right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel)

    # center_freqs = inverse_mel_scale(center_mel)  # size (num_bins)
    # size(1, num_fft_bins)
    mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)

    # size (num_bins, num_fft_bins)
    up_slope = (mel - left_mel) / (center_mel - left_mel)
    down_slope = (right_mel - mel) / (right_mel - center_mel)

    if vtln_warp_factor == 1.0:
        # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
        bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))
    else:
        # warping can move the order of left_mel, center_mel, right_mel anywhere
        bins = torch.zeros_like(up_slope)
        up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel)  # left_mel < mel <= center_mel
        down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel)  # center_mel < mel < right_mel
        bins[up_idx] = up_slope[up_idx]
        bins[down_idx] = down_slope[down_idx]

    return bins.to(device=device, dtype=dtype)  # , center_freqs

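
# Usage sketch (illustration only, not part of the original file; the helper
# name is hypothetical). 23 triangular filters over the 256 positive-frequency
# bins of a 512-point padded window at 16 kHz; high_freq <= 0 is interpreted
# as an offset from the Nyquist frequency.
def _demo_get_mel_banks():
    banks = get_mel_banks(
        num_bins=23, window_length_padded=512, sample_freq=16000.0,
        low_freq=20.0, high_freq=0.0, vtln_low=100.0, vtln_high=-500.0,
        vtln_warp_factor=1.0, device="cpu", dtype=torch.float32,
    )
    assert banks.shape == (23, 256)
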
cache = {}


def fbank(
    waveform: Tensor,
    blackman_coeff: float = 0.42,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    high_freq: float = 0.0,
    htk_compat: bool = False,
    low_freq: float = 20.0,
    min_duration: float = 0.0,
    num_mel_bins: int = 23,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    use_energy: bool = False,
    use_log_fbank: bool = True,
    use_power: bool = True,
    vtln_high: float = -500.0,
    vtln_low: float = 100.0,
    vtln_warp: float = 1.0,
    window_type: str = POVEY,
) -> Tensor:
    r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's
    compute-fbank-feats.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
        blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
        channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
            the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
        energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
            this floor is applied to the zeroth component, representing the total signal energy. The floor on the
            individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
            (Default: ``0.0``)
        htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features
            (need to change other parameters). (Default: ``False``)
        low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
            it this way. (Default: ``False``)
        use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
        use_log_fbank (bool, optional): If true, produce log-filterbank, else produce linear. (Default: ``True``)
        use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``)
        vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
            negative, offset from high-mel-freq) (Default: ``-500.0``)
        vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
        vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
        window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
            (Default: ``'povey'``)

    Returns:
        Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``)
        where m is calculated in _get_strided
    """
    device, dtype = waveform.device, waveform.dtype

    waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
        waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
    )

    if len(waveform) < min_duration * sample_frequency:
        # signal is too short
        return torch.empty(0, device=device, dtype=dtype)

    # strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
    strided_input, signal_log_energy = _get_window(
        waveform,
        padded_window_size,
        window_size,
        window_shift,
        window_type,
        blackman_coeff,
        snip_edges,
        raw_energy,
        energy_floor,
        dither,
        remove_dc_offset,
        preemphasis_coefficient,
    )

    # size (m, padded_window_size // 2 + 1)
    spectrum = torch.fft.rfft(strided_input).abs()
    if use_power:
        spectrum = spectrum.pow(2.0)

    # size (num_mel_bins, padded_window_size // 2)
    # print(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp)

    # The mel banks depend only on the parameters below, so memoize them per
    # (parameters, device, dtype) combination instead of recomputing every call.
    cache_key = "%s-%s-%s-%s-%s-%s-%s-%s-%s-%s" % (
        num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp, device, dtype,
    )
    if cache_key not in cache:
        mel_energies = get_mel_banks(
            num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp, device, dtype
        )
        cache[cache_key] = mel_energies
    else:
        mel_energies = cache[cache_key]

    # pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
    mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)

    # sum with mel filterbanks over the power spectrum, size (m, num_mel_bins)
    mel_energies = torch.mm(spectrum, mel_energies.T)
    if use_log_fbank:
        # avoid log of zero (which should be prevented anyway by dithering)
        mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()

    # if use_energy then add it as the last column for htk_compat == true else first column
    if use_energy:
        signal_log_energy = signal_log_energy.unsqueeze(1)  # size (m, 1)
        # returns size (m, num_mel_bins + 1)
        if htk_compat:
            mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1)
        else:
            mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1)

    mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
    return mel_energies

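
# Usage sketch (illustration only, not part of the original file; the helper
# name is hypothetical). 80-bin log-mel features of the kind speaker-embedding
# front ends typically consume; a second call with identical parameters reuses
# the memoized mel banks instead of rebuilding them.
def _demo_fbank():
    wav = torch.randn(1, 16000)  # (channels, samples), 1 s of fake mono audio
    feats = fbank(wav, num_mel_bins=80, sample_frequency=16000.0, dither=0.0)
    assert feats.shape == (98, 80)  # 1 + (16000 - 400) // 160 = 98 frames
    fbank(torch.randn(1, 16000), num_mel_bins=80, dither=0.0)
    assert len(cache) == 1  # assuming no earlier fbank call populated the cache
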
def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
    # returns a dct matrix of size (num_mel_bins, num_ceps)
    # size (num_mel_bins, num_mel_bins)
    dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho")
    # kaldi expects the first cepstral to be a weighted sum with factor sqrt(1/num_mel_bins)
    # this would be the first column in the dct_matrix for torchaudio as it expects a
    # right multiply (which would be the first column of kaldi's dct_matrix as kaldi
    # expects a left multiply e.g. dct_matrix * vector).
    dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins))
    dct_matrix = dct_matrix[:, :num_ceps]
    return dct_matrix


def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
    # returns size (num_ceps)
    # Compute liftering coefficients (scaling on cepstral coeffs)
    # coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
    i = torch.arange(num_ceps)
    return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter)


def mfcc(
    waveform: Tensor,
    blackman_coeff: float = 0.42,
    cepstral_lifter: float = 22.0,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    high_freq: float = 0.0,
    htk_compat: bool = False,
    low_freq: float = 20.0,
    num_ceps: int = 13,
    min_duration: float = 0.0,
    num_mel_bins: int = 23,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    use_energy: bool = False,
    vtln_high: float = -500.0,
    vtln_low: float = 100.0,
    vtln_warp: float = 1.0,
    window_type: str = POVEY,
) -> Tensor:
    r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's
    compute-mfcc-feats.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
        blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
        cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``)
        channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
            the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
        energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
            this floor is applied to the zeroth component, representing the total signal energy. The floor on the
            individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
            (Default: ``0.0``)
        htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible
            features (need to change other parameters). (Default: ``False``)
        low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
        num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
            it this way. (Default: ``False``)
        use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
        vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
            negative, offset from high-mel-freq) (Default: ``-500.0``)
        vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
        vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
        window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
            (Default: ``"povey"``)

    Returns:
        Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``)
        where m is calculated in _get_strided
    """
    assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins)

    device, dtype = waveform.device, waveform.dtype

    # The mel_energies should be squared (use_power=True), have no mean subtracted
    # (subtract_mean=False), and be logged (use_log_fbank=True).
    # size (m, num_mel_bins + use_energy)
    feature = fbank(
        waveform=waveform,
        blackman_coeff=blackman_coeff,
        channel=channel,
        dither=dither,
        energy_floor=energy_floor,
        frame_length=frame_length,
        frame_shift=frame_shift,
        high_freq=high_freq,
        htk_compat=htk_compat,
        low_freq=low_freq,
        min_duration=min_duration,
        num_mel_bins=num_mel_bins,
        preemphasis_coefficient=preemphasis_coefficient,
        raw_energy=raw_energy,
        remove_dc_offset=remove_dc_offset,
        round_to_power_of_two=round_to_power_of_two,
        sample_frequency=sample_frequency,
        snip_edges=snip_edges,
        subtract_mean=False,
        use_energy=use_energy,
        use_log_fbank=True,
        use_power=True,
        vtln_high=vtln_high,
        vtln_low=vtln_low,
        vtln_warp=vtln_warp,
        window_type=window_type,
    )

    if use_energy:
        # size (m)
        signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
        # offset is 0 if htk_compat==True else 1
        mel_offset = int(not htk_compat)
        feature = feature[:, mel_offset : (num_mel_bins + mel_offset)]

    # size (num_mel_bins, num_ceps)
    dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device)

    # size (m, num_ceps)
    feature = feature.matmul(dct_matrix)

    if cepstral_lifter != 0.0:
        # size (1, num_ceps)
        lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
        feature *= lifter_coeffs.to(device=device, dtype=dtype)

    # if use_energy then replace the last column for htk_compat == true else first column
    if use_energy:
        feature[:, 0] = signal_log_energy

    if htk_compat:
        energy = feature[:, 0].unsqueeze(1)  # size (m, 1)
        feature = feature[:, 1:]  # size (m, num_ceps - 1)
        if not use_energy:
            # scale on C0 (actually removing a scale we previously added that's
            # part of one common definition of the cosine transform.)
            energy *= math.sqrt(2)

        feature = torch.cat((feature, energy), dim=1)

    feature = _subtract_column_mean(feature, subtract_mean)
    return feature
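
A minimal usage sketch for mfcc (not part of the commit); it assumes eres2net is importable as a package and uses random audio as a stand-in:

import torch

from eres2net import kaldi

wav = torch.randn(1, 16000)
ceps = kaldi.mfcc(wav, num_ceps=13, num_mel_bins=23, dither=0.0)
print(ceps.shape)  # torch.Size([98, 13])
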
eres2net/pooling_layers.py
ADDED
@@ -0,0 +1,104 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""This implementation is adapted from https://github.com/wenet-e2e/wespeaker."""

import torch
import torch.nn as nn


class TAP(nn.Module):
    """
    Temporal average pooling; only the first-order mean is considered.
    """

    def __init__(self, **kwargs):
        super(TAP, self).__init__()

    def forward(self, x):
        pooling_mean = x.mean(dim=-1)
        # To be compatible with 2D input
        pooling_mean = pooling_mean.flatten(start_dim=1)
        return pooling_mean


class TSDP(nn.Module):
    """
    Temporal standard deviation pooling; only the second-order std is considered.
    """

    def __init__(self, **kwargs):
        super(TSDP, self).__init__()

    def forward(self, x):
        # The last dimension is the temporal axis
        pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
        pooling_std = pooling_std.flatten(start_dim=1)
        return pooling_std


class TSTP(nn.Module):
    """
    Temporal statistics pooling: concatenate mean and std, as used in
    the x-vector model.
    Comment: simple concatenation cannot make full use of both statistics.
    """

    def __init__(self, **kwargs):
        super(TSTP, self).__init__()

    def forward(self, x):
        # The last dimension is the temporal axis
        pooling_mean = x.mean(dim=-1)
        pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
        pooling_mean = pooling_mean.flatten(start_dim=1)
        pooling_std = pooling_std.flatten(start_dim=1)

        stats = torch.cat((pooling_mean, pooling_std), 1)
        return stats


class ASTP(nn.Module):
    """Attentive statistics pooling: channel- and context-dependent
    statistics pooling, first used in ECAPA-TDNN.
    """

    def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
        super(ASTP, self).__init__()
        self.global_context_att = global_context_att

        # Use Conv1d with stride == 1 rather than Linear, then we don't
        # need to transpose inputs.
        if global_context_att:
            self.linear1 = nn.Conv1d(in_dim * 3, bottleneck_dim, kernel_size=1)  # equals W and b in the paper
        else:
            self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        """
        x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
           or a 4-dimensional tensor in resnet architecture (B,C,F,T)
           0-dim: batch dimension, last dim: time (frame) dimension
        """
        if len(x.shape) == 4:
            x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
        assert len(x.shape) == 3

        if self.global_context_att:
            context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
            x_in = torch.cat((x, context_mean, context_std), dim=1)
        else:
            x_in = x

        # DON'T use ReLU here! With ReLU the attention weights may be hard to converge.
        alpha = torch.tanh(self.linear1(x_in))  # alpha = F.relu(self.linear1(x_in))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        var = torch.sum(alpha * (x**2), dim=2) - mean**2
        std = torch.sqrt(var.clamp(min=1e-10))
        return torch.cat([mean, std], dim=1)
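
A minimal usage sketch for the pooling layers (not part of the commit); the random tensor stands in for frame-level features, and both poolings collapse the time axis into one fixed-size vector per utterance:

import torch

from eres2net.pooling_layers import ASTP, TSTP

x = torch.randn(4, 256, 200)  # (batch, feature_dim, frames)
print(TSTP()(x).shape)  # torch.Size([4, 512]): mean concatenated with std
print(ASTP(in_dim=256)(x).shape)  # torch.Size([4, 512]): attention-weighted stats
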
feature_extractor/__init__.py
ADDED
@@ -0,0 +1,6 @@
from . import cnhubert, whisper_enc

content_module_map = {
    'cnhubert': cnhubert,
    'whisper': whisper_enc
}
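
A minimal dispatch sketch for content_module_map (not part of the commit); both modules expose the same get_model()/get_content() pair, and the checkpoint directory is a placeholder:

from feature_extractor import cnhubert, content_module_map

cnhubert.cnhubert_base_path = "pretrained_models/chinese-hubert-base"
module = content_module_map["cnhubert"]  # or "whisper"
model = module.get_model()
# feats = module.get_content(model, wav_16k_tensor)
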
feature_extractor/cnhubert.py
ADDED
@@ -0,0 +1,109 @@
import time

import librosa
import torch
import torch.nn.functional as F
import soundfile as sf
import os
from transformers import logging as tf_logging

tf_logging.set_verbosity_error()

import logging

logging.getLogger("numba").setLevel(logging.WARNING)

from transformers import (
    Wav2Vec2FeatureExtractor,
    HubertModel,
)

import utils
import torch.nn as nn

cnhubert_base_path = None


class CNHubert(nn.Module):
    def __init__(self):
        super().__init__()
        if not os.path.exists(cnhubert_base_path):
            raise FileNotFoundError(cnhubert_base_path)
        self.model = HubertModel.from_pretrained(cnhubert_base_path, local_files_only=True)
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            cnhubert_base_path, local_files_only=True
        )

    def forward(self, x):
        input_values = self.feature_extractor(
            x, return_tensors="pt", sampling_rate=16000
        ).input_values.to(x.device)
        feats = self.model(input_values)["last_hidden_state"]
        return feats


# class CNHubertLarge(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.model = HubertModel.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-hubert-large")
#         self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-hubert-large")
#     def forward(self, x):
#         input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
#         feats = self.model(input_values)["last_hidden_state"]
#         return feats
#
# class CVec(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.model = HubertModel.from_pretrained("/data/docker/liujing04/vc-webui-big/hubert_base")
#         self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/vc-webui-big/hubert_base")
#     def forward(self, x):
#         input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
#         feats = self.model(input_values)["last_hidden_state"]
#         return feats
#
# class cnw2v2base(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.model = Wav2Vec2Model.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-wav2vec2-base")
#         self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-wav2vec2-base")
#     def forward(self, x):
#         input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
#         feats = self.model(input_values)["last_hidden_state"]
#         return feats


def get_model():
    model = CNHubert()
    model.eval()
    return model


# def get_large_model():
#     model = CNHubertLarge()
#     model.eval()
#     return model
#
# def get_model_cvec():
#     model = CVec()
#     model.eval()
#     return model
#
# def get_model_cnw2v2base():
#     model = cnw2v2base()
#     model.eval()
#     return model


def get_content(hmodel, wav_16k_tensor):
    with torch.no_grad():
        feats = hmodel(wav_16k_tensor)
    return feats.transpose(1, 2)


if __name__ == "__main__":
    model = get_model()
    src_path = "/Users/Shared/原音频2.wav"
    wav_16k_tensor = utils.load_wav_to_torch_and_resample(src_path, 16000)
    model = model
    wav_16k_tensor = wav_16k_tensor
    feats = get_content(model, wav_16k_tensor)
    print(feats.shape)
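
A minimal usage sketch for CNHubert (not part of the commit); "example.wav" is a placeholder for any clip resampled to 16 kHz, and the checkpoint directory must contain a local chinese-hubert-base download:

import librosa
import torch

from feature_extractor import cnhubert

cnhubert.cnhubert_base_path = "pretrained_models/chinese-hubert-base"
model = cnhubert.get_model()
wav, _ = librosa.load("example.wav", sr=16000)
feats = cnhubert.get_content(model, torch.from_numpy(wav))
print(feats.shape)  # (1, 768, T): hidden size 768 at roughly 50 frames per second
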
feature_extractor/whisper_enc.py
ADDED
@@ -0,0 +1,25 @@
import torch


def get_model():
    import whisper

    model = whisper.load_model("small", device="cpu")

    return model.encoder


def get_content(model=None, wav_16k_tensor=None):
    from whisper import log_mel_spectrogram, pad_or_trim

    dev = next(model.parameters()).device
    mel = log_mel_spectrogram(wav_16k_tensor).to(dev)[:, :3000]
    # if torch.cuda.is_available():
    #     mel = mel.to(torch.float16)
    feature_len = mel.shape[-1] // 2
    assert mel.shape[-1] < 3000, "Input audio is too long; only audio within 30 seconds is allowed"
    with torch.no_grad():
        feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[
            :1, :feature_len, :
        ].transpose(1, 2)
    return feature
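
A minimal usage sketch for the Whisper encoder path (not part of the commit); it needs the openai-whisper package, and the random waveform stands in for a clip shorter than 30 s:

import torch

from feature_extractor import whisper_enc

encoder = whisper_enc.get_model()
wav_16k = torch.randn(16000 * 5)  # 5 s of fake mono audio
feats = whisper_enc.get_content(encoder, wav_16k)
print(feats.shape)  # (1, 768, 250): width 768 for "small"; 250 = 500 mel frames // 2
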
inference_webui.py
ADDED
|
@@ -0,0 +1,867 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging
import os
import re
import traceback
from time import time as ttime

import gradio as gr
import gradio.themes as themes
import librosa
import nltk
import numpy as np
import spaces
import torch
import torchaudio
from gradio.themes.utils import fonts
from huggingface_hub import snapshot_download
from transformers.models.auto.modeling_auto import AutoModelForMaskedLM
from transformers.models.auto.tokenization_auto import AutoTokenizer

from AR.models.structs import T2SRequest
from AR.models.t2s_model_flash_attn import CUDAGraphRunner
from feature_extractor import cnhubert
from module.mel_processing import spectrogram_torch
from module.models import SynthesizerTrn
from sv import SV
from text import chinese, cleaned_text_to_sequence
from text.cleaner import clean_text
from text.LangSegmenter import LangSegmenter
from tools.i18n.i18n import I18nAuto

logging.getLogger("markdown_it").setLevel(logging.ERROR)
logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("httpcore").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("asyncio").setLevel(logging.ERROR)
logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
logging.getLogger("python_multipart.multipart").setLevel(logging.ERROR)
logging.getLogger("split_lang.split.splitter").setLevel(logging.ERROR)
logging.getLogger("filelock").setLevel(logging.INFO)

os.makedirs("pretrained_models", exist_ok=True)

nltk.download("averaged_perceptron_tagger_eng")

snapshot_download(
    repo_id="lj1995/GPT-SoVITS",
    repo_type="model",
    allow_patterns="chinese*",
    local_dir="pretrained_models",
)
snapshot_download(
    repo_id="lj1995/GPT-SoVITS",
    repo_type="model",
    allow_patterns="s1v3.ckpt",
    local_dir="pretrained_models",
)
snapshot_download(
    repo_id="lj1995/GPT-SoVITS",
    repo_type="model",
    allow_patterns="sv*",
    local_dir="pretrained_models",
)
snapshot_download(
    repo_id="lj1995/GPT-SoVITS",
    repo_type="model",
    allow_patterns="v2Pro/s2Gv2ProPlus.pth",
    local_dir="pretrained_models",
)

version = "v2"  # os.environ.get("version","v2")
cnhubert_base_path = os.environ.get("cnhubert_base_path", "pretrained_models/chinese-hubert-base")
bert_path = os.environ.get("bert_path", "pretrained_models/chinese-roberta-wwm-ext-large")
cnhubert.cnhubert_base_path = cnhubert_base_path

punctuation = set(["!", "?", "…", ",", ".", "-", " "])


i18n = I18nAuto(language="Auto")

if torch.cuda.is_available():
    device = "cuda"
    is_half = True
else:
    device = "cpu"
    is_half = False

dict_language_v1 = {
    i18n("中文"): "all_zh",  # treat all text as Chinese
    i18n("英文"): "en",  # treat all text as English (unchanged)
    i18n("日文"): "all_ja",  # treat all text as Japanese
    i18n("中英混合"): "zh",  # mixed Chinese-English (unchanged)
    i18n("日英混合"): "ja",  # mixed Japanese-English (unchanged)
    i18n("多语种混合"): "auto",  # multilingual: split the text and detect the language per segment
}
dict_language_v2 = {
    i18n("中文"): "all_zh",  # treat all text as Chinese
    i18n("英文"): "en",  # treat all text as English (unchanged)
    i18n("日文"): "all_ja",  # treat all text as Japanese
    i18n("粤语"): "all_yue",  # treat all text as Cantonese
    i18n("韩文"): "all_ko",  # treat all text as Korean
    i18n("中英混合"): "zh",  # mixed Chinese-English (unchanged)
    i18n("日英混合"): "ja",  # mixed Japanese-English (unchanged)
    i18n("粤英混合"): "yue",  # mixed Cantonese-English (unchanged)
    i18n("韩英混合"): "ko",  # mixed Korean-English (unchanged)
    i18n("多语种混合"): "auto",  # multilingual: split the text and detect the language per segment
    i18n("多语种混合(粤语)"): "auto_yue",  # multilingual split with Cantonese preferred for Chinese segments
}
dict_language = dict_language_v1 if version == "v1" else dict_language_v2

tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
if is_half is True:
    bert_model = bert_model.half().to(device)
else:
    bert_model = bert_model.to(device)


def get_bert_feature(text, word2ph):
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        for i in inputs:
            inputs[i] = inputs[i].to(device)
        res = bert_model(**inputs, output_hidden_states=True)
        res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
    assert len(word2ph) == len(text)
    phone_level_feature = []
    for i in range(len(word2ph)):
        repeat_feature = res[i].repeat(word2ph[i], 1)
        phone_level_feature.append(repeat_feature)
    phone_level_feature = torch.cat(phone_level_feature, dim=0)
    return phone_level_feature.T
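
# Illustrative sketch (editor's note, not part of the original file): get_bert_feature()
# expands character-level BERT states to phoneme level. Each character's 1024-dim vector
# (taken from the third-to-last hidden layer, with CLS/SEP stripped) is repeated
# word2ph[i] times, so the output shape is [1024, sum(word2ph)]. Assuming a 3-character
# input whose characters map to 2, 1 and 3 phonemes:
#
#   feats = get_bert_feature("你好吗", [2, 1, 3])
#   assert feats.shape == (1024, 2 + 1 + 3)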


class DictToAttrRecursive(dict):
    def __init__(self, input_dict):
        super().__init__(input_dict)
        for key, value in input_dict.items():
            if isinstance(value, dict):
                value = DictToAttrRecursive(value)
            self[key] = value
            setattr(self, key, value)

    def __getattr__(self, item):
        try:
            return self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")

    def __setattr__(self, key, value):
        if isinstance(value, dict):
            value = DictToAttrRecursive(value)
        super(DictToAttrRecursive, self).__setitem__(key, value)
        super().__setattr__(key, value)

    def __delattr__(self, item):
        try:
            del self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")


ssl_model = cnhubert.get_model()
if is_half is True:
    ssl_model = ssl_model.half().to(device)
else:
    ssl_model = ssl_model.to(device)


def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
    global vq_model, hps, version, dict_language
    dict_s2 = torch.load(sovits_path, map_location="cpu")
    hps = dict_s2["config"]
    hps = DictToAttrRecursive(hps)
    hps.model.semantic_frame_rate = "25hz"
    if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
        hps.model.version = "v1"
    else:
        hps.model.version = "v2"
    version = hps.model.version
    # print("SoVITS version:", hps.model.version)
    vq_model = SynthesizerTrn(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    if "pretrained" not in sovits_path:
        del vq_model.enc_q
    if is_half is True:
        vq_model = vq_model.half().to(device)
    else:
        vq_model = vq_model.to(device)
    vq_model.eval()
    print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
    dict_language = dict_language_v1 if version == "v1" else dict_language_v2
    if prompt_language is not None and text_language is not None:
        if prompt_language in list(dict_language.keys()):
            prompt_text_update, prompt_language_update = (
                {"__type__": "update"},
                {"__type__": "update", "value": prompt_language},
            )
        else:
            prompt_text_update = {"__type__": "update", "value": ""}
            prompt_language_update = {"__type__": "update", "value": i18n("中文")}
        if text_language in list(dict_language.keys()):
            text_update, text_language_update = {"__type__": "update"}, {"__type__": "update", "value": text_language}
        else:
            text_update = {"__type__": "update", "value": ""}
            text_language_update = {"__type__": "update", "value": i18n("中文")}
        return (
            {"__type__": "update", "choices": list(dict_language.keys())},
            {"__type__": "update", "choices": list(dict_language.keys())},
            prompt_text_update,
            prompt_language_update,
            text_update,
            text_language_update,
        )


change_sovits_weights("pretrained_models/v2Pro/s2Gv2ProPlus.pth")


def change_gpt_weights(gpt_path):
    global t2s_model, config
    dict_s1 = torch.load(gpt_path, map_location="cpu")
    config = dict_s1["config"]
    t2s_model = CUDAGraphRunner(
        CUDAGraphRunner.load_decoder(gpt_path), torch.device(device), torch.float16 if is_half else torch.float32
    )
    total = sum(p.numel() for p in t2s_model.decoder_model.parameters())
    print("Number of parameters: %.2fM" % (total / 1e6))


change_gpt_weights("pretrained_models/s1v3.ckpt")


sv_cn_model = SV(device, is_half)

resample_transform_dict = {}


def resample(audio_tensor, sr0, sr1, device):
    global resample_transform_dict
    key = "%s-%s-%s" % (sr0, sr1, str(device))
    if key not in resample_transform_dict:
        resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
    return resample_transform_dict[key](audio_tensor)
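
# Illustrative sketch (editor's note, not part of the original file): resample() memoizes
# one torchaudio.transforms.Resample module per (sr0, sr1, device) key, so repeated
# conversions between the same pair of rates reuse the precomputed filter:
#
#   wav = torch.randn(1, 32000)                  # hypothetical 1 s clip at 32 kHz
#   out = resample(wav, 32000, 16000, "cpu")     # builds and caches the 32k->16k transform
#   out = resample(wav, 32000, 16000, "cpu")     # cache hit: the same Resample module is reused
#   assert out.shape[-1] == 16000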


def get_spepc(hps, filename, dtype, device, is_v2pro=False):
    sr1 = int(hps.data.sampling_rate)
    audio, sr0 = torchaudio.load(filename)
    if sr0 != sr1:
        audio = audio.to(device)
        if audio.shape[0] == 2:
            audio = audio.mean(0).unsqueeze(0)
        audio = resample(audio, sr0, sr1, device)
    else:
        audio = audio.to(device)
        if audio.shape[0] == 2:
            audio = audio.mean(0).unsqueeze(0)

    maxx = audio.abs().max()
    if maxx > 1:
        audio /= min(2, maxx)
    spec = spectrogram_torch(
        audio,
        hps.data.filter_length,
        hps.data.sampling_rate,
        hps.data.hop_length,
        hps.data.win_length,
        center=False,
    )
    spec = spec.to(dtype)
    if is_v2pro is True:
        audio = resample(audio, sr1, 16000, device).to(dtype)
    return spec, audio
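
# Editor's note (illustrative, not part of the original file): get_spepc() downmixes
# stereo to mono, resamples to the model's sampling rate, peak-limits the waveform, and
# returns the linear spectrogram; with is_v2pro=True it also returns the 16 kHz audio
# that feeds the speaker-verification embedding. Rough shapes, assuming a mono clip:
#
#   spec, audio16k = get_spepc(hps, "ref.wav", dtype, device, is_v2pro=True)
#   # spec: [1, hps.data.filter_length // 2 + 1, n_frames]; audio16k: [1, n_samples_16k]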


def clean_text_inf(text, language, version):
    language = language.replace("all_", "")
    phones, word2ph, norm_text = clean_text(text, language, version)
    phones = cleaned_text_to_sequence(phones, version)
    return phones, word2ph, norm_text


dtype = torch.float16 if is_half is True else torch.float32


def get_bert_inf(phones, word2ph, norm_text, language):
    language = language.replace("all_", "")
    if language == "zh":
        bert = get_bert_feature(norm_text, word2ph).to(device)  # .to(dtype)
    else:
        bert = torch.zeros(
            (1024, len(phones)),
            dtype=torch.float16 if is_half is True else torch.float32,
        ).to(device)

    return bert


splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…"}


def get_first(text):
    pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
    text = re.split(pattern, text)[0].strip()
    return text


def get_phones_and_bert(text, language, version, final=False):
    if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
        formattext = text
        while "  " in formattext:
            formattext = formattext.replace("  ", " ")
        if language == "all_zh":
            if re.search(r"[A-Za-z]", formattext):
                formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
                formattext = chinese.mix_text_normalize(formattext)
                return get_phones_and_bert(formattext, "zh", version)
            else:
                phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
                bert = get_bert_feature(norm_text, word2ph).to(device)
        elif language == "all_yue" and re.search(r"[A-Za-z]", formattext):
            formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
            formattext = chinese.mix_text_normalize(formattext)
            return get_phones_and_bert(formattext, "yue", version)
        else:
            phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
            bert = torch.zeros(
                (1024, len(phones)),
                dtype=torch.float16 if is_half is True else torch.float32,
            ).to(device)
    elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
        textlist = []
        langlist = []
        if language == "auto":
            for tmp in LangSegmenter.getTexts(text):
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        elif language == "auto_yue":
            for tmp in LangSegmenter.getTexts(text):
                if tmp["lang"] == "zh":
                    tmp["lang"] = "yue"
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        else:
            for tmp in LangSegmenter.getTexts(text):
                if tmp["lang"] == "en":
                    langlist.append(tmp["lang"])
                else:
                    # CJK scripts cannot be told apart reliably, so trust the user's language choice
                    langlist.append(language)
                textlist.append(tmp["text"])
        print(textlist)
        print(langlist)
        phones_list = []
        bert_list = []
        norm_text_list = []
        for i in range(len(textlist)):
            lang = langlist[i]
            phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
            bert = get_bert_inf(phones, word2ph, norm_text, lang)
            phones_list.append(phones)
            norm_text_list.append(norm_text)
            bert_list.append(bert)
        bert = torch.cat(bert_list, dim=1)
        phones = sum(phones_list, [])
        norm_text = "".join(norm_text_list)

    if not final and len(phones) < 6:
        return get_phones_and_bert("." + text, language, version, final=True)

    return phones, bert.to(dtype), norm_text
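
# Illustrative sketch (editor's note, not part of the original file): in the
# mixed-language branches the text is segmented per language, each segment is cleaned
# separately, and the phoneme lists / BERT features are concatenated back in order;
# inputs shorter than 6 phonemes are retried once with a leading "." for stability.
# The invariant to rely on:
#
#   phones, bert, norm_text = get_phones_and_bert("GPT-SoVITS真好用", "zh", "v2")
#   assert bert.shape[1] == len(phones)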


def merge_short_text_in_array(texts, threshold):
    if (len(texts)) < 2:
        return texts
    result = []
    text = ""
    for ele in texts:
        text += ele
        if len(text) >= threshold:
            result.append(text)
            text = ""
    if len(text) > 0:
        if len(result) == 0:
            result.append(text)
        else:
            result[len(result) - 1] += text
    return result
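
# Illustrative examples (editor's note, not part of the original file): segments are
# merged greedily until the running length reaches the threshold, and a short leftover
# tail is glued onto the last merged chunk so nothing is dropped:
#
#   assert merge_short_text_in_array(["ab", "cd", "e"], 4) == ["abcde"]
#   assert merge_short_text_in_array(["ab", "cde", "f", "gh"], 3) == ["abcde", "fgh"]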


# Cache key would be: ref_wav_path + prompt_text + prompt_language + text (per sentence) + text_language + top_k + top_p + temperature
# cache_tokens = {}  # no eviction mechanism implemented yet
cache = {}


@spaces.GPU
def get_tts_wav(
    ref_wav_path,
    prompt_text,
    prompt_language,
    text,
    text_language,
    how_to_cut=i18n("不切"),
    top_k=20,
    top_p=0.6,
    temperature=0.6,
    ref_free=False,
    speed=1,
    if_freeze=False,
    inp_refs=123,
):
    global cache
    if ref_wav_path:
        pass
    else:
        gr.Warning(i18n("请上传参考音频"))
    if text:
        pass
    else:
        gr.Warning(i18n("请填入推理文本"))
    t = []
    if prompt_text is None or len(prompt_text) == 0:
        ref_free = True
    t0 = ttime()
    prompt_language = dict_language[prompt_language]
    text_language = dict_language[text_language]

    if not ref_free:
        prompt_text = prompt_text.strip("\n")
        if prompt_text[-1] not in splits:
            prompt_text += "。" if prompt_language != "en" else "."
        print(i18n("实际输入的参考文本:"), prompt_text)
    text = text.strip("\n")
    if text[0] not in splits and len(get_first(text)) < 4:
        text = "。" + text if text_language != "en" else "." + text

    print(i18n("实际输入的目标文本:"), text)
    zero_wav = np.zeros(
        int(hps.data.sampling_rate * 0.3),
        dtype=np.float16 if is_half is True else np.float32,
    )
    if not ref_free:
        with torch.no_grad():
            wav16k, sr = librosa.load(ref_wav_path, sr=16000)
            if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
                gr.Warning(i18n("参考音频在3~10秒范围外,请更换!"))
                raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
            wav16k = torch.from_numpy(wav16k)
            zero_wav_torch = torch.from_numpy(zero_wav)
            if is_half is True:
                wav16k = wav16k.half().to(device)
                zero_wav_torch = zero_wav_torch.half().to(device)
            else:
                wav16k = wav16k.to(device)
                zero_wav_torch = zero_wav_torch.to(device)
            wav16k = torch.cat([wav16k, zero_wav_torch])
            ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
            codes = vq_model.extract_latent(ssl_content)
            prompt_semantic = codes[0, 0]
            prompt = prompt_semantic.unsqueeze(0).to(device)

    t1 = ttime()
    t.append(t1 - t0)

    if how_to_cut == i18n("凑四句一切"):
        text = cut1(text)
    elif how_to_cut == i18n("凑50字一切"):
        text = cut2(text)
    elif how_to_cut == i18n("按中文句号。切"):
        text = cut3(text)
    elif how_to_cut == i18n("按英文句号.切"):
        text = cut4(text)
    elif how_to_cut == i18n("按标点符号切"):
        text = cut5(text)
    while "\n\n" in text:
        text = text.replace("\n\n", "\n")
    print(i18n("实际输入的目标文本(切句后):"), text)
    texts = text.split("\n")
    texts = process_text(texts)
    texts = merge_short_text_in_array(texts, 5)
    audio_opt = []
    if not ref_free:
        phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)

    infer_speed: list[float] = []

    for i_text, text in enumerate(texts):
        # Skip empty lines in the target text; they would otherwise raise an error
        if len(text.strip()) == 0:
            continue
        if text[-1] not in splits:
            text += "。" if text_language != "en" else "."
        print(i18n("实际输入的目标文本(每句):"), text)
        phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
        print(i18n("前端处理后的文本(每句):"), norm_text2)
        if not ref_free:
            bert = torch.cat([bert1, bert2], 1)
            all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
        else:
            bert = bert2
            all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)

        bert = bert.to(device).unsqueeze(0)
        all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)

        t2 = ttime()
        # cache_key = "%s-%s-%s-%s-%s-%s-%s-%s" % (ref_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature)
        # print(cache.keys(), if_freeze)
        if i_text in cache and if_freeze is True:
            pred_semantic = cache[i_text]
        else:
            with torch.no_grad():
                t2s_request = T2SRequest(
                    [all_phoneme_ids.squeeze(0)],
                    all_phoneme_len,
                    all_phoneme_ids.new_zeros((1, 0)) if ref_free else prompt,
                    [bert.squeeze(0)],
                    valid_length=1,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    early_stop_num=1500,
                    use_cuda_graph=True,
                    # debug=True,
                )
                t2s_result = t2s_model.generate(t2s_request)

                if t2s_result.exception is not None:
                    print(t2s_result.traceback)
                    raise t2s_result.exception

                infer_speed.append(t2s_result.infer_speed)
                pred_semantic = t2s_result.result
                assert pred_semantic
                cache[i_text] = pred_semantic
        t3 = ttime()
        refers = []
        sv_emb = []
        if inp_refs:
            for path in inp_refs:
                try:
                    refer, audio_tensor = get_spepc(hps, path.name, dtype, device, is_v2pro=True)
                    refers.append(refer)
                    sv_emb.append(sv_cn_model.compute_embedding3(audio_tensor))
                except Exception:
                    traceback.print_exc()
        if len(refers) == 0:
            refers, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device, is_v2pro=True)
            refers = [refers]
            sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)]
        audio = (
            vq_model.decode(
                pred_semantic[0].unsqueeze(0).unsqueeze(0),
                torch.LongTensor(phones2).to(device).unsqueeze(0),
                refers,
                speed=speed,
                sv_emb=sv_emb,
            )
            .detach()
            .cpu()
            .numpy()[0][0]
        )
        max_audio = np.abs(audio).max()  # simple guard against 16-bit clipping
        if max_audio > 1:
            audio /= max_audio
        audio_opt.append(audio)
        audio_opt.append(zero_wav)
        t4 = ttime()
        t.extend([t2 - t1, t3 - t2, t4 - t3])
        t1 = ttime()
    print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
    gr.Info(f"{sum(infer_speed) / len(infer_speed):.2f} Token/s", title="Infer Speed")
    gr.Info("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])), title="Time Stamps")
    yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
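
# Illustrative usage sketch (editor's note, not part of the original file): get_tts_wav()
# is a generator that yields one (sampling_rate, int16_waveform) tuple. A minimal
# zero-shot call, assuming "ref.wav" is a 3-10 second reference clip with a matching
# transcript:
#
#   gen = get_tts_wav("ref.wav", "参考音频的转写文本", i18n("中文"),
#                     "需要合成的目标文本", i18n("中文"), how_to_cut=i18n("凑四句一切"))
#   sr, wav = next(gen)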


def split(todo_text):
    todo_text = todo_text.replace("……", "。").replace("——", ",")
    if todo_text[-1] not in splits:
        todo_text += "。"
    i_split_head = i_split_tail = 0
    len_text = len(todo_text)
    todo_texts = []
    while 1:
        if i_split_head >= len_text:
            break  # the text is guaranteed to end with punctuation, so the last segment was already appended
        if todo_text[i_split_head] in splits:
            i_split_head += 1
            todo_texts.append(todo_text[i_split_tail:i_split_head])
            i_split_tail = i_split_head
        else:
            i_split_head += 1
    return todo_texts


def cut1(inp):
    inp = inp.strip("\n")
    inps = split(inp)
    split_idx = list(range(0, len(inps), 4))
    split_idx[-1] = None
    if len(split_idx) > 1:
        opts = []
        for idx in range(len(split_idx) - 1):
            opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
    else:
        opts = [inp]
    opts = [item for item in opts if not set(item).issubset(punctuation)]
    return "\n".join(opts)


def cut2(inp):
    inp = inp.strip("\n")
    inps = split(inp)
    if len(inps) < 2:
        return inp
    opts = []
    summ = 0
    tmp_str = ""
    for i in range(len(inps)):
        summ += len(inps[i])
        tmp_str += inps[i]
        if summ > 50:
            summ = 0
            opts.append(tmp_str)
            tmp_str = ""
    if tmp_str != "":
        opts.append(tmp_str)
    # print(opts)
    if len(opts) > 1 and len(opts[-1]) < 50:  # if the last chunk is too short, merge it into the previous one
        opts[-2] = opts[-2] + opts[-1]
        opts = opts[:-1]
    opts = [item for item in opts if not set(item).issubset(punctuation)]
    return "\n".join(opts)


def cut3(inp):
    inp = inp.strip("\n")
    opts = ["%s" % item for item in inp.strip("。").split("。")]
    opts = [item for item in opts if not set(item).issubset(punctuation)]
    return "\n".join(opts)


def cut4(inp):
    inp = inp.strip("\n")
    opts = ["%s" % item for item in inp.strip(".").split(".")]
    opts = [item for item in opts if not set(item).issubset(punctuation)]
    return "\n".join(opts)


# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
def cut5(inp):
    inp = inp.strip("\n")
    punds = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}
    mergeitems = []
    items = []

    for i, char in enumerate(inp):
        if char in punds:
            if char == "." and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
                items.append(char)
            else:
                items.append(char)
                mergeitems.append("".join(items))
                items = []
        else:
            items.append(char)

    if items:
        mergeitems.append("".join(items))

    opt = [item for item in mergeitems if not set(item).issubset(punds)]
    return "\n".join(opt)
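
# Illustrative example (editor's note, not part of the original file): cut5() splits after
# every punctuation mark but deliberately keeps a "." that sits between two digits, so
# decimal numbers survive the split:
#
#   assert cut5("pi is 3.14. nice!") == "pi is 3.14.\n nice!"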


def custom_sort_key(s):
    # Split the string into digit and non-digit runs with a regex
    parts = re.split(r"(\d+)", s)
    # Convert the digit runs to integers; leave the rest as strings
    parts = [int(part) if part.isdigit() else part for part in parts]
    return parts
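
# Illustrative example (editor's note, not part of the original file): custom_sort_key()
# yields a natural sort, so "file10" orders after "file2" instead of before it:
#
#   names = ["file10", "file2", "file1"]
#   assert sorted(names, key=custom_sort_key) == ["file1", "file2", "file10"]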


def process_text(texts):
    _text = []
    if all(text in [None, " ", "\n", ""] for text in texts):
        raise ValueError(i18n("请输入有效文本"))
    for text in texts:
        if text in [None, " ", ""]:
            pass
        else:
            _text.append(text)
    return _text


def html_center(text, label="p"):
    return f"""<div style="text-align: center; margin: 100; padding: 50;">
<{label} style="margin: 0; padding: 0;">{text}</{label}>
</div>"""


def html_left(text, label="p"):
    return f"""<div style="text-align: left; margin: 0; padding: 0;">
<{label} style="margin: 0; padding: 0;">{text}</{label}>
</div>"""


theme = themes.Soft(
    font=(
        "-apple-system",
        fonts.GoogleFont("Inter"),
        fonts.GoogleFont("Quicksand"),
        "ui-sans-serif",
        "sans-serif",
    )
)
theme.block_border_width = "1px"

with gr.Blocks(
    title="GPT-SoVITS WebUI",
    theme=theme,
    analytics_enabled=False,
) as app:
    gr.Markdown(
        value="""# GPT-SoVITS-ProPlus Zero-shot TTS Demo
## https://github.com/RVC-Boss/GPT-SoVITS
Input 3 to 10s of reference audio to guide the timbre, speed and emotion of the voice, then generate the speech you want by entering the inference text. <br>
输入3至10秒的参考音频来引导待合成语音的音色、语速和情感,然后输入待合成目标文本,生成目标语音. <br>
Cross-lingual support: inference in languages different from the training dataset, currently supporting English, Japanese, Korean and Cantonese.<br>
目前支持中日英韩粤跨语种合成。<br>
This demo is open source under the MIT license. The author does not have any control over it. Users who use the software and distribute the sounds exported by the software are solely responsible. If you do not agree with this clause, you cannot use or reference any code or files within this demo. <br>
本demo以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. 如不认可该条款, 则不能使用或引用该demo内的任何代码和文件.
"""
    )
    gr.Markdown(html_center(i18n("*请上传并填写参考信息"), "h3"))
    with gr.Row(equal_height=True):
        inp_ref = gr.Audio(label=i18n("请上传3~10秒内参考音频,超过会报错!"), type="filepath")
        with gr.Column():
            ref_text_free = gr.Checkbox(
                label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"),
                value=False,
                interactive=True,
                show_label=True,
            )
            prompt_text = gr.Textbox(
                label=i18n("参考音频的文本"),
                value="",
                lines=3,
                max_lines=3,
                info=i18n(
                    "使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开。<br>开启后无视填写的参考文本。"
                ),
            )
        prompt_language = gr.Dropdown(
            label=i18n("参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
        )
        inp_refs = gr.File(
            label=i18n(
                "可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。"
            ),
            file_count="multiple",
        )
    gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"), "h3"))
    with gr.Row(equal_height=True):
        with gr.Column():
            text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=26, max_lines=26)
        with gr.Column():
            text_language = gr.Dropdown(
                label=i18n("需要合成的语种") + i18n(".限制范围越小判别效果越好。"),
                choices=list(dict_language.keys()),
                value=i18n("中文"),
            )
            how_to_cut = gr.Dropdown(
                label=i18n("怎么切"),
                choices=[
                    i18n("不切"),
                    i18n("凑四句一切"),
                    i18n("凑50字一切"),
                    i18n("按中文句号。切"),
                    i18n("按英文句号.切"),
                    i18n("按标点符号切"),
                ],
                value=i18n("凑四句一切"),
                interactive=True,
            )
            gr.Markdown(value=html_center(i18n("语速调整,高为更快")))
            if_freeze = gr.Checkbox(
                label=i18n("是否直接对上次合成结果调整语速和音色。防止随机性。"),
                value=False,
                interactive=True,
                show_label=True,
            )
            speed = gr.Slider(minimum=0.6, maximum=1.65, step=0.05, label=i18n("语速"), value=1, interactive=True)
            gr.Markdown(html_center(i18n("GPT采样参数(无参考文本时不要太低。不懂就用默认):")))
            top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=15, interactive=True)
            top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
            temperature = gr.Slider(
                minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True
            )
    with gr.Row(equal_height=True):
        inference_button = gr.Button(i18n("合成语音"), variant="primary", size="lg")
        output = gr.Audio(label=i18n("输出的语音"))

    inference_button.click(
        get_tts_wav,
        [
            inp_ref,
            prompt_text,
            prompt_language,
            text,
            text_language,
            how_to_cut,
            top_k,
            top_p,
            temperature,
            ref_text_free,
            speed,
            if_freeze,
            inp_refs,
        ],
        [output],
    )

if __name__ == "__main__":
    import tempfile
    import wave

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_file:
        file_name = temp_file.name
        with wave.open(temp_file, "w") as wav_file:
            channels = 1
            sample_width = 2
            sample_rate = 44100
            duration = 5
            frequency = 440.0

            t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
            sine_wave = np.sin(2 * np.pi * frequency * t)  # sine wave
            int_wave = (sine_wave * 32767).astype(np.int16)

            wav_file.setnchannels(channels)  # pylint: disable=no-member
            wav_file.setsampwidth(sample_width)  # pylint: disable=no-member
            wav_file.setframerate(sample_rate)  # pylint: disable=no-member
            wav_file.writeframes(int_wave.tobytes())  # pylint: disable=no-member

        gen = get_tts_wav(
            ref_wav_path=file_name,
            prompt_text="",
            prompt_language=i18n("中文"),
            text="犯大吴疆土者,盛必击而破之,犯大吴疆土者,盛必击而破之,犯大吴疆土者,盛必击而破之,犯大吴疆土者,盛必击而破之",
            text_language=i18n("中文"),
            inp_refs=[],
        )
        next(gen)

    app.queue().launch(
        server_name="0.0.0.0",
        inbrowser=True,
        show_api=False,
        allowed_paths=["/"],
    )
module/__init__.py
ADDED
File without changes
module/attentions.py
ADDED
@@ -0,0 +1,709 @@
import math

import torch
from torch import nn
from torch.nn import functional as F

from module import commons
from module.modules import LayerNorm


class Encoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=4,
        isflow=False,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))
        if isflow:
            cond_layer = torch.nn.Conv1d(
                kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
            )
            self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
            self.cond_layer = weight_norm_modules(cond_layer, name="weight")
            self.gin_channels = kwargs["gin_channels"]

    def forward(self, x, x_mask, g=None):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            if g is not None:
                x = self.cond_pre(x)
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
                x = commons.fused_add_tanh_sigmoid_multiply(
                    x, g_l, torch.IntTensor([self.hidden_channels])
                )
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class Decoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=True,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.encdec_attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.self_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    proximal_bias=proximal_bias,
                    proximal_init=proximal_init,
                )
            )
            self.norm_layers_0.append(LayerNorm(hidden_channels))
            self.encdec_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                    causal=True,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, h, h_mask):
        """
        x: decoder input
        h: encoder output
        """
        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
            device=x.device, dtype=x.dtype
        )
        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            y = self.self_attn_layers[i](x, x, self_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_0[i](x + y)

            y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        p_dropout=0.0,
        window_size=None,
        heads_share=True,
        block_length=None,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            assert (
                t_s == t_t
            ), "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(
                query / math.sqrt(self.k_channels), key_relative_embeddings
            )
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(
                device=scores.device, dtype=scores.dtype
            )
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (
                    t_s == t_t
                ), "Local attention is only available for self-attention."
                block_mask = (
                    torch.ones_like(scores)
                    .triu(-self.block_length)
                    .tril(self.block_length)
                )
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, t_s
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )
        output = (
            output.transpose(2, 3).contiguous().view(b, d, t_t)
        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        max_relative_position = 2 * self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(
            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
        )

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
            :, :, :length, length - 1 :
        ]
        return x_final
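
    # Illustrative sketch (editor's note, not part of the original file): the pad-and-
    # reshape trick above turns per-query relative-position logits [b, h, l, 2*l-1] into
    # an absolute [b, h, l, l] score grid without any gather ops. A quick shape check,
    # for some MultiHeadAttention instance `attn`:
    #
    #   rel = torch.randn(1, 2, 4, 7)   # l = 4, so 2*l - 1 = 7 relative offsets
    #   absolute = attn._relative_position_to_absolute_position(rel)
    #   assert absolute.shape == (1, 2, 4, 4)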

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # pad along column
        x = F.pad(
            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
        )
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
            length: an integer scalar.
        Returns:
            a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
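
    # Illustrative sketch (editor's note, not part of the original file): the proximal
    # bias is -log(1 + |i - j|): zero on the diagonal and increasingly negative with
    # query-key distance, nudging self-attention toward nearby positions:
    #
    #   bias = attn._attention_bias_proximal(3)[0, 0]   # attn: any MultiHeadAttention
    #   # bias[0] = [0.0, -log(2), -log(3)]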


class FFN(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        filter_channels,
        kernel_size,
        p_dropout=0.0,
        activation=None,
        causal=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        if causal:
            self.padding = self._causal_padding
        else:
            self.padding = self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            x = x * torch.sigmoid(1.702 * x)
        else:
            x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(self.padding(x * x_mask))
        return x * x_mask

    def _causal_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = self.kernel_size - 1
        pad_r = 0
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x

    def _same_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = (self.kernel_size - 1) // 2
        pad_r = self.kernel_size // 2
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x
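
# Illustrative sketch (editor's note, not part of the original file): with kernel_size k,
# FFN._causal_padding pads k-1 zeros on the left only, so each output frame of conv_1
# depends on current and past frames; _same_padding splits the k-1 zeros across both
# sides to keep the sequence length unchanged:
#
#   ffn = FFN(8, 8, 16, kernel_size=3, causal=True)
#   x = torch.randn(1, 8, 10)
#   assert ffn._causal_padding(x).shape[-1] == 12   # 10 + (3 - 1) left-pad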
+
|
| 435 |
+
import torch.nn as nn
|
| 436 |
+
from torch.nn.utils import remove_weight_norm, weight_norm
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
class Depthwise_Separable_Conv1D(nn.Module):
|
| 440 |
+
def __init__(
|
| 441 |
+
self,
|
| 442 |
+
in_channels,
|
| 443 |
+
out_channels,
|
| 444 |
+
kernel_size,
|
| 445 |
+
stride=1,
|
| 446 |
+
padding=0,
|
| 447 |
+
dilation=1,
|
| 448 |
+
bias=True,
|
| 449 |
+
padding_mode="zeros", # TODO: refine this type
|
| 450 |
+
device=None,
|
| 451 |
+
dtype=None,
|
| 452 |
+
):
|
| 453 |
+
super().__init__()
|
| 454 |
+
self.depth_conv = nn.Conv1d(
|
| 455 |
+
in_channels=in_channels,
|
| 456 |
+
out_channels=in_channels,
|
| 457 |
+
kernel_size=kernel_size,
|
| 458 |
+
groups=in_channels,
|
| 459 |
+
stride=stride,
|
| 460 |
+
padding=padding,
|
| 461 |
+
dilation=dilation,
|
| 462 |
+
bias=bias,
|
| 463 |
+
padding_mode=padding_mode,
|
| 464 |
+
device=device,
|
| 465 |
+
dtype=dtype,
|
| 466 |
+
)
|
| 467 |
+
self.point_conv = nn.Conv1d(
|
| 468 |
+
in_channels=in_channels,
|
| 469 |
+
out_channels=out_channels,
|
| 470 |
+
kernel_size=1,
|
| 471 |
+
bias=bias,
|
| 472 |
+
device=device,
|
| 473 |
+
dtype=dtype,
|
| 474 |
+
)
|
| 475 |
+
|
| 476 |
+
def forward(self, input):
|
| 477 |
+
return self.point_conv(self.depth_conv(input))
|
| 478 |
+
|
| 479 |
+
def weight_norm(self):
|
| 480 |
+
self.depth_conv = weight_norm(self.depth_conv, name="weight")
|
| 481 |
+
self.point_conv = weight_norm(self.point_conv, name="weight")
|
| 482 |
+
|
| 483 |
+
def remove_weight_norm(self):
|
| 484 |
+
self.depth_conv = remove_weight_norm(self.depth_conv, name="weight")
|
| 485 |
+
self.point_conv = remove_weight_norm(self.point_conv, name="weight")
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
class Depthwise_Separable_TransposeConv1D(nn.Module):
|
| 489 |
+
def __init__(
|
| 490 |
+
self,
|
| 491 |
+
in_channels,
|
| 492 |
+
out_channels,
|
| 493 |
+
kernel_size,
|
| 494 |
+
stride=1,
|
| 495 |
+
padding=0,
|
| 496 |
+
output_padding=0,
|
| 497 |
+
bias=True,
|
| 498 |
+
dilation=1,
|
| 499 |
+
padding_mode="zeros", # TODO: refine this type
|
| 500 |
+
device=None,
|
| 501 |
+
dtype=None,
|
| 502 |
+
):
|
| 503 |
+
super().__init__()
|
| 504 |
+
self.depth_conv = nn.ConvTranspose1d(
|
| 505 |
+
in_channels=in_channels,
|
| 506 |
+
out_channels=in_channels,
|
| 507 |
+
kernel_size=kernel_size,
|
| 508 |
+
groups=in_channels,
|
| 509 |
+
stride=stride,
|
| 510 |
+
output_padding=output_padding,
|
| 511 |
+
padding=padding,
|
| 512 |
+
dilation=dilation,
|
| 513 |
+
bias=bias,
|
| 514 |
+
padding_mode=padding_mode,
|
| 515 |
+
device=device,
|
| 516 |
+
dtype=dtype,
|
| 517 |
+
)
|
| 518 |
+
self.point_conv = nn.Conv1d(
|
| 519 |
+
in_channels=in_channels,
|
| 520 |
+
out_channels=out_channels,
|
| 521 |
+
kernel_size=1,
|
| 522 |
+
bias=bias,
|
| 523 |
+
device=device,
|
| 524 |
+
dtype=dtype,
|
| 525 |
+
)
|
| 526 |
+
|
| 527 |
+
def forward(self, input):
|
| 528 |
+
return self.point_conv(self.depth_conv(input))
|
| 529 |
+
|
| 530 |
+
def weight_norm(self):
|
| 531 |
+
self.depth_conv = weight_norm(self.depth_conv, name="weight")
|
| 532 |
+
self.point_conv = weight_norm(self.point_conv, name="weight")
|
| 533 |
+
|
| 534 |
+
def remove_weight_norm(self):
|
| 535 |
+
remove_weight_norm(self.depth_conv, name="weight")
|
| 536 |
+
remove_weight_norm(self.point_conv, name="weight")
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
def weight_norm_modules(module, name="weight", dim=0):
|
| 540 |
+
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
|
| 541 |
+
module, Depthwise_Separable_TransposeConv1D
|
| 542 |
+
):
|
| 543 |
+
module.weight_norm()
|
| 544 |
+
return module
|
| 545 |
+
else:
|
| 546 |
+
return weight_norm(module, name, dim)
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
def remove_weight_norm_modules(module, name="weight"):
|
| 550 |
+
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
|
| 551 |
+
module, Depthwise_Separable_TransposeConv1D
|
| 552 |
+
):
|
| 553 |
+
module.remove_weight_norm()
|
| 554 |
+
else:
|
| 555 |
+
remove_weight_norm(module, name)
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
class FFT(nn.Module):
|
| 559 |
+
def __init__(
|
| 560 |
+
self,
|
| 561 |
+
hidden_channels,
|
| 562 |
+
filter_channels,
|
| 563 |
+
n_heads,
|
| 564 |
+
n_layers=1,
|
| 565 |
+
kernel_size=1,
|
| 566 |
+
p_dropout=0.0,
|
| 567 |
+
proximal_bias=False,
|
| 568 |
+
proximal_init=True,
|
| 569 |
+
isflow=False,
|
| 570 |
+
**kwargs
|
| 571 |
+
):
|
| 572 |
+
super().__init__()
|
| 573 |
+
self.hidden_channels = hidden_channels
|
| 574 |
+
self.filter_channels = filter_channels
|
| 575 |
+
self.n_heads = n_heads
|
| 576 |
+
self.n_layers = n_layers
|
| 577 |
+
self.kernel_size = kernel_size
|
| 578 |
+
self.p_dropout = p_dropout
|
| 579 |
+
self.proximal_bias = proximal_bias
|
| 580 |
+
self.proximal_init = proximal_init
|
| 581 |
+
if isflow:
|
| 582 |
+
cond_layer = torch.nn.Conv1d(
|
| 583 |
+
kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
|
| 584 |
+
)
|
| 585 |
+
self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
|
| 586 |
+
self.cond_layer = weight_norm_modules(cond_layer, name="weight")
|
| 587 |
+
self.gin_channels = kwargs["gin_channels"]
|
| 588 |
+
self.drop = nn.Dropout(p_dropout)
|
| 589 |
+
self.self_attn_layers = nn.ModuleList()
|
| 590 |
+
self.norm_layers_0 = nn.ModuleList()
|
| 591 |
+
self.ffn_layers = nn.ModuleList()
|
| 592 |
+
self.norm_layers_1 = nn.ModuleList()
|
| 593 |
+
for i in range(self.n_layers):
|
| 594 |
+
self.self_attn_layers.append(
|
| 595 |
+
MultiHeadAttention(
|
| 596 |
+
hidden_channels,
|
| 597 |
+
hidden_channels,
|
| 598 |
+
n_heads,
|
| 599 |
+
p_dropout=p_dropout,
|
| 600 |
+
proximal_bias=proximal_bias,
|
| 601 |
+
proximal_init=proximal_init,
|
| 602 |
+
)
|
| 603 |
+
)
|
| 604 |
+
self.norm_layers_0.append(LayerNorm(hidden_channels))
|
| 605 |
+
self.ffn_layers.append(
|
| 606 |
+
FFN(
|
| 607 |
+
hidden_channels,
|
| 608 |
+
hidden_channels,
|
| 609 |
+
filter_channels,
|
| 610 |
+
kernel_size,
|
| 611 |
+
p_dropout=p_dropout,
|
| 612 |
+
causal=True,
|
| 613 |
+
)
|
| 614 |
+
)
|
| 615 |
+
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
| 616 |
+
|
| 617 |
+
def forward(self, x, x_mask, g=None):
|
| 618 |
+
"""
|
| 619 |
+
x: decoder input
|
| 620 |
+
h: encoder output
|
| 621 |
+
"""
|
| 622 |
+
if g is not None:
|
| 623 |
+
g = self.cond_layer(g)
|
| 624 |
+
|
| 625 |
+
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
|
| 626 |
+
device=x.device, dtype=x.dtype
|
| 627 |
+
)
|
| 628 |
+
x = x * x_mask
|
| 629 |
+
for i in range(self.n_layers):
|
| 630 |
+
if g is not None:
|
| 631 |
+
x = self.cond_pre(x)
|
| 632 |
+
cond_offset = i * 2 * self.hidden_channels
|
| 633 |
+
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
|
| 634 |
+
x = commons.fused_add_tanh_sigmoid_multiply(
|
| 635 |
+
x, g_l, torch.IntTensor([self.hidden_channels])
|
| 636 |
+
)
|
| 637 |
+
y = self.self_attn_layers[i](x, x, self_attn_mask)
|
| 638 |
+
y = self.drop(y)
|
| 639 |
+
x = self.norm_layers_0[i](x + y)
|
| 640 |
+
|
| 641 |
+
y = self.ffn_layers[i](x, x_mask)
|
| 642 |
+
y = self.drop(y)
|
| 643 |
+
x = self.norm_layers_1[i](x + y)
|
| 644 |
+
x = x * x_mask
|
| 645 |
+
return x
|
| 646 |
+
|
| 647 |
+
|
| 648 |
+
class TransformerCouplingLayer(nn.Module):
|
| 649 |
+
def __init__(
|
| 650 |
+
self,
|
| 651 |
+
channels,
|
| 652 |
+
hidden_channels,
|
| 653 |
+
kernel_size,
|
| 654 |
+
n_layers,
|
| 655 |
+
n_heads,
|
| 656 |
+
p_dropout=0,
|
| 657 |
+
filter_channels=0,
|
| 658 |
+
mean_only=False,
|
| 659 |
+
wn_sharing_parameter=None,
|
| 660 |
+
gin_channels=0,
|
| 661 |
+
):
|
| 662 |
+
assert channels % 2 == 0, "channels should be divisible by 2"
|
| 663 |
+
super().__init__()
|
| 664 |
+
self.channels = channels
|
| 665 |
+
self.hidden_channels = hidden_channels
|
| 666 |
+
self.kernel_size = kernel_size
|
| 667 |
+
self.n_layers = n_layers
|
| 668 |
+
self.half_channels = channels // 2
|
| 669 |
+
self.mean_only = mean_only
|
| 670 |
+
|
| 671 |
+
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
|
| 672 |
+
self.enc = (
|
| 673 |
+
Encoder(
|
| 674 |
+
hidden_channels,
|
| 675 |
+
filter_channels,
|
| 676 |
+
n_heads,
|
| 677 |
+
n_layers,
|
| 678 |
+
kernel_size,
|
| 679 |
+
p_dropout,
|
| 680 |
+
isflow=True,
|
| 681 |
+
gin_channels=gin_channels,
|
| 682 |
+
)
|
| 683 |
+
if wn_sharing_parameter is None
|
| 684 |
+
else wn_sharing_parameter
|
| 685 |
+
)
|
| 686 |
+
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
|
| 687 |
+
self.post.weight.data.zero_()
|
| 688 |
+
self.post.bias.data.zero_()
|
| 689 |
+
|
| 690 |
+
def forward(self, x, x_mask, g=None, reverse=False):
|
| 691 |
+
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
| 692 |
+
h = self.pre(x0) * x_mask
|
| 693 |
+
h = self.enc(h, x_mask, g=g)
|
| 694 |
+
stats = self.post(h) * x_mask
|
| 695 |
+
if not self.mean_only:
|
| 696 |
+
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
|
| 697 |
+
else:
|
| 698 |
+
m = stats
|
| 699 |
+
logs = torch.zeros_like(m)
|
| 700 |
+
|
| 701 |
+
if not reverse:
|
| 702 |
+
x1 = m + x1 * torch.exp(logs) * x_mask
|
| 703 |
+
x = torch.cat([x0, x1], 1)
|
| 704 |
+
logdet = torch.sum(logs, [1, 2])
|
| 705 |
+
return x, logdet
|
| 706 |
+
else:
|
| 707 |
+
x1 = (x1 - m) * torch.exp(-logs) * x_mask
|
| 708 |
+
x = torch.cat([x0, x1], 1)
|
| 709 |
+
return x
|
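A minimal sketch (not part of the commit) of why the Depthwise_Separable_Conv1D above is used: factoring a Conv1d into a per-channel depthwise pass plus a 1x1 pointwise pass keeps the output shape while cutting the weight count roughly by the kernel size. The channel/kernel sizes below are arbitrary illustration values.

import torch
import torch.nn as nn

standard = nn.Conv1d(256, 256, kernel_size=5, padding=2)
separable = Depthwise_Separable_Conv1D(256, 256, kernel_size=5, padding=2)

# standard:  256*256*5 weights + 256 biases              = 327,936 parameters
# separable: depthwise 256*1*5 + 256, pointwise 256*256*1 + 256 = 67,328 parameters
print(sum(p.numel() for p in standard.parameters()))
print(sum(p.numel() for p in separable.parameters()))

x = torch.randn(1, 256, 100)
assert standard(x).shape == separable(x).shape == (1, 256, 100)  # same output shape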
module/attentions_onnx.py
ADDED
@@ -0,0 +1,354 @@
import logging
import math
import torch
from torch import nn
from torch.nn import functional as F

from module import commons
from module.modules import LayerNorm


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


class Encoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=4,
        isflow=True,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        # if isflow:
        #     cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
        #     self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
        #     self.cond_layer = weight_norm(cond_layer, name='weight')
        #     self.gin_channels = 256
        self.cond_layer_idx = self.n_layers
        if "gin_channels" in kwargs:
            self.gin_channels = kwargs["gin_channels"]
            if self.gin_channels != 0:
                self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
                # vits2 says 3rd block, so idx is 2 by default
                self.cond_layer_idx = (
                    kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
                )
                logging.debug("gin_channels: %s, cond_layer_idx: %s", self.gin_channels, self.cond_layer_idx)
                assert (
                    self.cond_layer_idx < self.n_layers
                ), "cond_layer_idx should be less than n_layers"
        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, g=None):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            if i == self.cond_layer_idx and g is not None:
                g = self.spk_emb_linear(g.transpose(1, 2))
                g = g.transpose(1, 2)
                x = x + g
                x = x * x_mask
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        p_dropout=0.0,
        window_size=None,
        heads_share=True,
        block_length=None,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, _ = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))

        if self.window_size is not None:
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(
                query / math.sqrt(self.k_channels), key_relative_embeddings
            )
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)

        p_attn = F.softmax(scores, dim=-1)
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)

        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )

        output = output.transpose(2, 3).contiguous().view(b, d, -1)
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        max_relative_position = 2 * self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        pad_l = torch.zeros((1), dtype=torch.int64) + length - (self.window_size + 1)
        pad_s = torch.zeros((1), dtype=torch.int64) + (self.window_size + 1) - length
        pad_length = torch.max(pad_l, other=torch.zeros((1), dtype=torch.int64))
        slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype=torch.int64))

        slice_end_position = slice_start_position + 2 * length - 1
        padded_relative_embeddings = F.pad(
            relative_embeddings,
            commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
        )
        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(
            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
        )

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
            :, :, :length, length - 1 :
        ]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # pad along column
        x = F.pad(
            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
        )
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
            length: an integer scalar.
        Returns:
            a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)


class FFN(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        filter_channels,
        kernel_size,
        p_dropout=0.0,
        activation=None,
        causal=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        if causal:
            self.padding = self._causal_padding
        else:
            self.padding = self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            x = x * torch.sigmoid(1.702 * x)
        else:
            x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(self.padding(x * x_mask))
        return x * x_mask

    def _causal_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = self.kernel_size - 1
        pad_r = 0
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x

    def _same_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = (self.kernel_size - 1) // 2
        pad_r = self.kernel_size // 2
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x
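A shape-only sketch (not from the commit) of the relative-to-absolute reindexing used in the windowed attention above: a [b, h, l, 2l-1] tensor of per-offset logits is padded, flattened, and re-viewed into the [b, h, l, l] score layout, and the inverse mapping gathers per-offset weights for the value embeddings. It assumes `module.commons` is importable so `convert_pad_shape` resolves; all sizes are illustrative.

import torch

mha = MultiHeadAttention(192, 192, 2, window_size=4)

l = 6
rel_logits = torch.randn(1, 2, l, 2 * l - 1)           # one logit per relative offset
scores = mha._relative_position_to_absolute_position(rel_logits)
print(scores.shape)                                    # torch.Size([1, 2, 6, 6])

# the inverse direction, used to weight the relative value embeddings:
back = mha._absolute_position_to_relative_position(scores)
print(back.shape)                                      # torch.Size([1, 2, 6, 11])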
module/commons.py
ADDED
@@ -0,0 +1,189 @@
import math
import torch
from torch.nn import functional as F


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q)"""
    kl = (logs_q - logs_p) - 0.5
    kl += (
        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
    )
    return kl


def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(uniform_samples))


def rand_gumbel_like(x):
    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
    return g


def slice_segments(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]
    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
        num_timescales - 1
    )
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal


def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + signal.to(dtype=x.dtype, device=x.device)


def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)


def subsequent_mask(length):
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
    return mask


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def shift_1d(x):
    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
    return x


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """
    device = duration.device

    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path


def clip_grad_value_(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm


def squeeze(x, x_mask=None, n_sqz=2):
    b, c, t = x.size()

    t = (t // n_sqz) * n_sqz
    x = x[:, :, :t]
    x_sqz = x.view(b, c, t // n_sqz, n_sqz)
    x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)

    if x_mask is not None:
        x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz]
    else:
        x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
    return x_sqz * x_mask, x_mask


def unsqueeze(x, x_mask=None, n_sqz=2):
    b, c, t = x.size()

    x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
    x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)

    if x_mask is not None:
        x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
    else:
        x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
    return x_unsqz * x_mask, x_mask
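Two of the helpers above are easy to sanity-check by hand (a sketch, not part of the commit): sequence_mask turns lengths into boolean padding masks, and convert_pad_shape reorders a [[dim0], [dim1], ...] spec into the last-dimension-first flat list that F.pad expects.

import torch

print(sequence_mask(torch.tensor([2, 4])))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])

# pad one step on the left of the last (time) axis of a [b, c, t] tensor:
print(convert_pad_shape([[0, 0], [0, 0], [1, 0]]))
# [1, 0, 0, 0, 0, 0]  -> matches F.pad's (last_left, last_right, ...) ordering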
module/core_vq.py
ADDED
@@ -0,0 +1,383 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This implementation is inspired from
# https://github.com/lucidrains/vector-quantize-pytorch
# which is released under MIT License. Hereafter, the original license:
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Core vector quantization implementation."""
import typing as tp

from einops import rearrange, repeat
import torch
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm


def default(val: tp.Any, d: tp.Any) -> tp.Any:
    return val if val is not None else d


def ema_inplace(moving_avg, new, decay: float):
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    return (x + epsilon) / (x.sum() + n_categories * epsilon)


def uniform_init(*shape: int):
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t


def sample_vectors(samples, num: int):
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


def kmeans(samples, num_clusters: int, num_iters: int = 10):
    dim, dtype = samples.shape[-1], samples.dtype
    max_kmeans_samples = 500
    samples = samples[:max_kmeans_samples, :]
    means = sample_vectors(samples, num_clusters)

    print("kmeans start ... ")
    for _ in tqdm(range(num_iters)):
        diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
        dists = -(diffs**2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins


class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.
    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.decay = decay
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
            uniform_init if not kmeans_init else torch.zeros
        )
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    @torch.jit.ignore
    def init_embed_(self, data):
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        # broadcast_tensors(self.buffers())

    def replace_(self, samples, mask):
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        # broadcast_tensors(self.buffers())

    def preprocess(self, x):
        x = rearrange(x, "... d -> (...) d")
        return x

    def quantize(self, x):
        embed = self.embed.t()
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x):
        shape = x.shape
        # pre-process
        x = self.preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)

        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

        return quantize, embed_ind


class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently supports only euclidean distance.
    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        commitment_weight: float = 1.0,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        requires_projection = _codebook_dim != dim
        self.project_in = (
            nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
        )
        self.project_out = (
            nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
        )

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(
            dim=_codebook_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
        )
        self.codebook_size = codebook_size

    @property
    def codebook(self):
        return self._codebook.embed

    def encode(self, x):
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind):
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def forward(self, x):
        device = x.device
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)

        quantize, embed_ind = self._codebook(x)

        if self.training:
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize, embed_ind, loss


class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """

    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
        )

    def forward(
        self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
    ):
        quantized_out = 0.0
        residual = x

        all_losses = []
        all_indices = []
        out_quantized = []

        n_q = n_q or len(self.layers)

        for i, layer in enumerate(self.layers[:n_q]):
            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss)
            if layers and i in layers:
                out_quantized.append(quantized)

        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
        return quantized_out, out_indices, out_losses, out_quantized

    def encode(
        self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
    ) -> torch.Tensor:
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        st = st or 0
        for layer in self.layers[st:n_q]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[st + i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out
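A minimal round-trip sketch (not part of the commit) of the VectorQuantization module above. kmeans_init=False skips the k-means pass on the first forward, eval() skips the EMA codebook updates, and all sizes are illustrative.

import torch

vq = VectorQuantization(dim=64, codebook_size=128, kmeans_init=False)
vq.eval()                                  # no EMA codebook updates in eval mode

x = torch.randn(2, 64, 40)                 # [batch, dim, time]
quantized, indices, loss = vq(x)
print(quantized.shape)                     # torch.Size([2, 64, 40])
print(indices.shape)                       # torch.Size([2, 40]), one codebook id per frame
print(loss)                                # zero commitment loss outside training

# encode/decode round-trip through the discrete codes:
codes = vq.encode(x)                       # [2, 40]
recon = vq.decode(codes)                   # [2, 64, 40]
assert torch.equal(recon, quantized)       # same lookups into the same codebook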
module/data_utils.py
ADDED
@@ -0,0 +1,332 @@
import time
import logging
import os
import random
import traceback
import numpy as np
import torch
import torch.utils.data
from tqdm import tqdm

from module import commons
from module.mel_processing import spectrogram_torch
from text import cleaned_text_to_sequence
from utils import load_wav_to_torch, load_filepaths_and_text
import torch.nn.functional as F
from functools import lru_cache
import requests
from scipy.io import wavfile
from io import BytesIO
from tools.my_utils import load_audio

version = os.environ.get("version", None)


# ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79)
class TextAudioSpeakerLoader(torch.utils.data.Dataset):
    """
    1) loads audio, speaker_id, text pairs
    2) normalizes text and converts them to sequences of integers
    3) computes spectrograms from audio files.
    """

    def __init__(self, hparams, val=False):
        exp_dir = hparams.exp_dir
        self.path2 = "%s/2-name2text.txt" % exp_dir
        self.path4 = "%s/4-cnhubert" % exp_dir
        self.path5 = "%s/5-wav32k" % exp_dir
        assert os.path.exists(self.path2)
        assert os.path.exists(self.path4)
        assert os.path.exists(self.path5)
        names4 = set([name[:-3] for name in list(os.listdir(self.path4))])  # strip the .pt suffix
        names5 = set(os.listdir(self.path5))
        self.phoneme_data = {}
        with open(self.path2, "r", encoding="utf8") as f:
            lines = f.read().strip("\n").split("\n")

        for line in lines:
            tmp = line.split("\t")
            if len(tmp) != 4:
                continue
            self.phoneme_data[tmp[0]] = [tmp[1]]

        self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
        tmp = self.audiopaths_sid_text
        leng = len(tmp)
        min_num = 100
        if leng < min_num:
            self.audiopaths_sid_text = []
            for _ in range(max(2, int(min_num / leng))):
                self.audiopaths_sid_text += tmp
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.filter_length = hparams.filter_length
        self.hop_length = hparams.hop_length
        self.win_length = hparams.win_length
        self.sampling_rate = hparams.sampling_rate
        self.val = val

        random.seed(1234)
        random.shuffle(self.audiopaths_sid_text)

        print("phoneme_data_len:", len(self.phoneme_data.keys()))
        print("wav_data_len:", len(self.audiopaths_sid_text))

        audiopaths_sid_text_new = []
        lengths = []
        skipped_phone = 0
        skipped_dur = 0
        for audiopath in tqdm(self.audiopaths_sid_text):
            try:
                phoneme = self.phoneme_data[audiopath][0]
                phoneme = phoneme.split(" ")
                phoneme_ids = cleaned_text_to_sequence(phoneme, version)
            except Exception:
                print(f"{audiopath} not in self.phoneme_data !")
                skipped_phone += 1
                continue

            size = os.path.getsize("%s/%s" % (self.path5, audiopath))
            duration = size / self.sampling_rate / 2

            if duration == 0:
                print(f"Zero duration for {audiopath}, skipping...")
                skipped_dur += 1
                continue

            if 54 > duration > 0.6 or self.val:
                audiopaths_sid_text_new.append([audiopath, phoneme_ids])
                lengths.append(size // (2 * self.hop_length))
            else:
                skipped_dur += 1
                continue

        print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
        print("total left: ", len(audiopaths_sid_text_new))
        assert len(audiopaths_sid_text_new) > 1  # must at least be enough for one batch; TODO
        self.audiopaths_sid_text = audiopaths_sid_text_new
        self.lengths = lengths

    def get_audio_text_speaker_pair(self, audiopath_sid_text):
        audiopath, phoneme_ids = audiopath_sid_text
        text = torch.FloatTensor(phoneme_ids)
        try:
            spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
            with torch.no_grad():
                ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
                if ssl.shape[-1] != spec.shape[-1]:
                    typee = ssl.dtype
                    ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
                ssl.requires_grad = False
        except Exception:
            traceback.print_exc()
            spec = torch.zeros(1025, 100)
            wav = torch.zeros(1, 100 * self.hop_length)
            ssl = torch.zeros(1, 768, 100)
            text = text[-1:]
            print("load audio or ssl error!!!!!!", audiopath)
        return (ssl, spec, wav, text)

    def get_audio(self, filename):
        audio_array = load_audio(filename, self.sampling_rate)  # load_audio already returns samples normalized to [-1, 1], so no further /32768 is needed
        audio = torch.FloatTensor(audio_array)  # /32768
        audio_norm = audio
        audio_norm = audio_norm.unsqueeze(0)
        spec = spectrogram_torch(audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length,
                                 center=False)
        spec = torch.squeeze(spec, 0)
        return spec, audio_norm

    def get_sid(self, sid):
        sid = torch.LongTensor([int(sid)])
        return sid

    def __getitem__(self, index):
        # with torch.no_grad():
        return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])

    def __len__(self):
        return len(self.audiopaths_sid_text)

    def random_slice(self, ssl, wav, mel):
        assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
            "first", ssl.shape, wav.shape)

        len_mel = mel.shape[1]
        if self.val:
            reference_mel = mel[:, :len_mel // 3]
            return reference_mel, ssl, wav, mel
        dir = random.randint(0, 1)
        sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2))

        if dir == 0:
            reference_mel = mel[:, :sep_point]
            ssl = ssl[:, :, sep_point:]
            wav2 = wav[:, sep_point * self.hop_length:]
            mel = mel[:, sep_point:]
        else:
            reference_mel = mel[:, sep_point:]
            ssl = ssl[:, :, :sep_point]
            wav2 = wav[:, :sep_point * self.hop_length]
            mel = mel[:, :sep_point]

        assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
            ssl.shape, wav.shape, wav2.shape, mel.shape, sep_point, self.hop_length, sep_point * self.hop_length, dir)
        return reference_mel, ssl, wav2, mel


class TextAudioSpeakerCollate():
    """Zero-pads model inputs and targets"""

    def __init__(self, return_ids=False):
        self.return_ids = return_ids

    def __call__(self, batch):
        """Collate's training batch from normalized text, audio and speaker identities
        PARAMS
        ------
        batch: [text_normalized, spec_normalized, wav_normalized, sid]
        """
        # Right zero-pad all one-hot text sequences to max input length
        _, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([x[1].size(1) for x in batch]),
            dim=0, descending=True)

        max_ssl_len = max([x[0].size(2) for x in batch])
        max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
        max_spec_len = max([x[1].size(1) for x in batch])
        max_spec_len = int(2 * ((max_spec_len // 2) + 1))
        max_wav_len = max([x[2].size(1) for x in batch])
        max_text_len = max([x[3].size(0) for x in batch])

        ssl_lengths = torch.LongTensor(len(batch))
        spec_lengths = torch.LongTensor(len(batch))
        wav_lengths = torch.LongTensor(len(batch))
        text_lengths = torch.LongTensor(len(batch))

        spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
|
| 206 |
+
wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
|
| 207 |
+
ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
|
| 208 |
+
text_padded = torch.LongTensor(len(batch), max_text_len)
|
| 209 |
+
|
| 210 |
+
spec_padded.zero_()
|
| 211 |
+
wav_padded.zero_()
|
| 212 |
+
ssl_padded.zero_()
|
| 213 |
+
text_padded.zero_()
|
| 214 |
+
|
| 215 |
+
for i in range(len(ids_sorted_decreasing)):
|
| 216 |
+
row = batch[ids_sorted_decreasing[i]]
|
| 217 |
+
|
| 218 |
+
ssl = row[0]
|
| 219 |
+
ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
|
| 220 |
+
ssl_lengths[i] = ssl.size(2)
|
| 221 |
+
|
| 222 |
+
spec = row[1]
|
| 223 |
+
spec_padded[i, :, :spec.size(1)] = spec
|
| 224 |
+
spec_lengths[i] = spec.size(1)
|
| 225 |
+
|
| 226 |
+
wav = row[2]
|
| 227 |
+
wav_padded[i, :, :wav.size(1)] = wav
|
| 228 |
+
wav_lengths[i] = wav.size(1)
|
| 229 |
+
|
| 230 |
+
text = row[3]
|
| 231 |
+
text_padded[i, :text.size(0)] = text
|
| 232 |
+
text_lengths[i] = text.size(0)
|
| 233 |
+
|
| 234 |
+
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
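# A minimal usage sketch of the loader + collate pair, assuming a hypothetical `hps`
# object shaped like the "data" section of configs/s2.json (not defined in this file):
#
#     dataset = TextAudioSpeakerLoader(hps.data)
#     loader = torch.utils.data.DataLoader(
#         dataset, batch_size=4, collate_fn=TextAudioSpeakerCollate(), shuffle=False
#     )
#     ssl, ssl_len, spec, spec_len, wav, wav_len, text, text_len = next(iter(loader))
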
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
    """
    Maintain similar input lengths in a batch.
    Length groups are specified by boundaries.
    Ex) boundaries = [b1, b2, b3] -> any batch is drawn entirely from {x | b1 < length(x) <= b2} or from {x | b2 < length(x) <= b3}.

    It removes samples which are not included in the boundaries.
    Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 is discarded.
    """

    def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        self.lengths = dataset.lengths
        self.batch_size = batch_size
        self.boundaries = boundaries

        self.buckets, self.num_samples_per_bucket = self._create_buckets()
        self.total_size = sum(self.num_samples_per_bucket)
        self.num_samples = self.total_size // self.num_replicas

    def _create_buckets(self):
        buckets = [[] for _ in range(len(self.boundaries) - 1)]
        for i in range(len(self.lengths)):
            length = self.lengths[i]
            idx_bucket = self._bisect(length)
            if idx_bucket != -1:
                buckets[idx_bucket].append(i)

        # drop empty buckets and their upper boundaries
        i = len(buckets) - 1
        while i >= 0:
            if len(buckets[i]) == 0:
                buckets.pop(i)
                self.boundaries.pop(i + 1)
            i -= 1

        num_samples_per_bucket = []
        for i in range(len(buckets)):
            len_bucket = len(buckets[i])
            total_batch_size = self.num_replicas * self.batch_size
            rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
            num_samples_per_bucket.append(len_bucket + rem)
        return buckets, num_samples_per_bucket

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.epoch)

        indices = []
        if self.shuffle:
            for bucket in self.buckets:
                indices.append(torch.randperm(len(bucket), generator=g).tolist())
        else:
            for bucket in self.buckets:
                indices.append(list(range(len(bucket))))

        batches = []
        for i in range(len(self.buckets)):
            bucket = self.buckets[i]
            len_bucket = len(bucket)
            ids_bucket = indices[i]
            num_samples_bucket = self.num_samples_per_bucket[i]

            # pad by repetition so every replica sees the same number of full batches
            rem = num_samples_bucket - len_bucket
            ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: (rem % len_bucket)]

            # subsample for the current rank
            ids_bucket = ids_bucket[self.rank :: self.num_replicas]

            for j in range(len(ids_bucket) // self.batch_size):
                batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size : (j + 1) * self.batch_size]]
                batches.append(batch)

        if self.shuffle:
            batch_ids = torch.randperm(len(batches), generator=g).tolist()
            batches = [batches[i] for i in batch_ids]
        self.batches = batches

        assert len(self.batches) * self.batch_size == self.num_samples
        return iter(self.batches)

    def _bisect(self, x, lo=0, hi=None):
        if hi is None:
            hi = len(self.boundaries) - 1

        if hi > lo:
            mid = (hi + lo) // 2
            if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
                return mid
            elif x <= self.boundaries[mid]:
                return self._bisect(x, lo, mid)
            else:
                return self._bisect(x, mid + 1, hi)
        else:
            return -1

    def __len__(self):
        return self.num_samples // self.batch_size
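The bucket sampler yields index batches directly, so it plugs into a DataLoader as a batch_sampler. A minimal sketch, assuming a hypothetical `hps` config object and illustrative boundary values (spec-frame lengths, the unit stored in `dataset.lengths`):

    dataset = TextAudioSpeakerLoader(hps.data)
    sampler = DistributedBucketSampler(
        dataset,
        batch_size=8,
        boundaries=[32, 300, 400, 500, 600, 700, 800, 900, 1000],
        num_replicas=1,
        rank=0,
        shuffle=True,
    )
    loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=sampler, collate_fn=TextAudioSpeakerCollate()
    )
    for epoch in range(10):
        sampler.set_epoch(epoch)  # reseeds the per-epoch shuffle, as with any DistributedSampler
        for batch in loader:
            pass

Samples whose length falls outside (boundaries[0], boundaries[-1]] are dropped, and each non-empty bucket is padded by repetition so every replica draws the same number of full batches.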
module/losses.py
ADDED
@@ -0,0 +1,73 @@
import math

import torch
from torch.nn import functional as F


def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            rl = rl.float().detach()
            gl = gl.float()
            loss += torch.mean(torch.abs(rl - gl))

    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        dr = dr.float()
        dg = dg.float()
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg**2)
        loss += r_loss + g_loss
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        dg = dg.float()
        l = torch.mean((1 - dg) ** 2)
        gen_losses.append(l)
        loss += l

    return loss, gen_losses


def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
    """
    z_p, logs_q: [b, h, t_t]
    m_p, logs_p: [b, h, t_t]
    """
    z_p = z_p.float()
    logs_q = logs_q.float()
    m_p = m_p.float()
    logs_p = logs_p.float()
    z_mask = z_mask.float()

    kl = logs_p - logs_q - 0.5
    kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
    kl = torch.sum(kl * z_mask)
    l = kl / torch.sum(z_mask)
    return l


def mle_loss(z, m, logs, logdet, mask):
    l = torch.sum(logs) + 0.5 * torch.sum(
        torch.exp(-2 * logs) * ((z - m) ** 2)
    )  # neg normal likelihood w/o the constant term
    l = l - torch.sum(logdet)  # log jacobian determinant
    l = l / torch.sum(
        torch.ones_like(z) * mask
    )  # averaging across batch, channel and time axes
    l = l + 0.5 * math.log(2 * math.pi)  # add the remaining constant term
    return l
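As a quick smoke test of the GAN losses above on dummy discriminator outputs (a sketch; the tensor shapes here are arbitrary placeholders, not anything the file prescribes):

    import torch
    from module.losses import discriminator_loss, generator_loss, feature_loss

    real_logits = [torch.randn(2, 100)]  # one discriminator head
    fake_logits = [torch.randn(2, 100)]
    d_loss, r_losses, g_losses = discriminator_loss(real_logits, fake_logits)  # LSGAN: D(real)->1, D(fake)->0
    g_loss, _ = generator_loss(fake_logits)                                    # LSGAN: D(fake)->1
    fm_loss = feature_loss([[torch.randn(2, 8, 50)]], [[torch.randn(2, 8, 50)]])

kl_loss averages logs_p - logs_q - 0.5 + 0.5 * (z_p - m_p)**2 * exp(-2 * logs_p) over the unmasked positions, i.e. the usual VITS prior/posterior KL term evaluated at the flow-mapped sample z_p.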
module/mel_processing.py
ADDED
@@ -0,0 +1,153 @@
import math
import os
import random
import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.data
import numpy as np
import librosa
import librosa.util as librosa_util
from librosa.util import normalize, pad_center, tiny
from scipy.signal import get_window
from scipy.io.wavfile import read
from librosa.filters import mel as librosa_mel_fn

MAX_WAV_VALUE = 32768.0


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
            dtype=y.dtype, device=y.device
        )

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)
    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window[wnsize_dtype_device],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=False,
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec


def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
    global mel_basis
    dtype_device = str(spec.dtype) + "_" + str(spec.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
            dtype=spec.dtype, device=spec.device
        )
    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = spectral_normalize_torch(spec)
    return spec


def mel_spectrogram_torch(
    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
):
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
            dtype=y.dtype, device=y.device
        )
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
            dtype=y.dtype, device=y.device
        )

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window[wnsize_dtype_device],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=False,
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = spectral_normalize_torch(spec)

    return spec
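A minimal sketch of the STFT helpers above. The 32 kHz / 2048-point values are illustrative, chosen to match this repo's s2 config; treat them as assumptions rather than defaults of these functions:

    import torch
    from module.mel_processing import spectrogram_torch, spec_to_mel_torch, mel_spectrogram_torch

    wav = torch.randn(1, 32000).clamp(-1, 1)  # one second of audio in [-1, 1]
    spec = spectrogram_torch(wav, 2048, 32000, 640, 2048, center=False)      # (1, 1025, T) linear magnitudes
    mel = spec_to_mel_torch(spec, 2048, 128, 32000, 0, None)                 # (1, 128, T) log-mel
    mel2 = mel_spectrogram_torch(wav, 2048, 128, 32000, 640, 2048, 0, None)  # same result in one call

Both helpers cache their Hann window and mel filterbank per (size, dtype, device) key in the module-level dicts, so repeated calls reuse them.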
module/models.py
ADDED
@@ -0,0 +1,1040 @@
import warnings

warnings.filterwarnings("ignore")
import copy
import math
import os
import pdb

import torch
from torch import nn
from torch.nn import functional as F

from module import commons
from module import modules
from module import attentions

from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer

# from text import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
from torch.cuda.amp import autocast
import contextlib


class StochasticDurationPredictor(nn.Module):
    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        filter_channels = in_channels  # it needs to be removed from future version.
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.log_flow = modules.Log()
        self.flows = nn.ModuleList()
        self.flows.append(modules.ElementwiseAffine(2))
        for i in range(n_flows):
            self.flows.append(
                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.flows.append(modules.Flip())

        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        self.post_flows = nn.ModuleList()
        self.post_flows.append(modules.ElementwiseAffine(2))
        for i in range(4):
            self.post_flows.append(
                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.post_flows.append(modules.Flip())

        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        x = torch.detach(x)
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if not reverse:
            flows = self.flows
            assert w is not None

            logdet_tot_q = 0
            h_w = self.post_pre(w)
            h_w = self.post_convs(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask
            e_q = (
                torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
                * x_mask
            )
            z_q = e_q
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
                logdet_tot_q += logdet_q
            z_u, z1 = torch.split(z_q, [1, 1], 1)
            u = torch.sigmoid(z_u) * x_mask
            z0 = (w - u) * x_mask
            logdet_tot_q += torch.sum(
                (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
            )
            logq = (
                torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
                - logdet_tot_q
            )

            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
            logdet_tot += logdet
            z = torch.cat([z0, z1], 1)
            for flow in flows:
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
            nll = (
                torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
                - logdet_tot
            )
            return nll + logq  # [b]
        else:
            flows = list(reversed(self.flows))
            flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
            z = (
                torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
                * noise_scale
            )
            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            z0, z1 = torch.split(z, [1, 1], 1)
            logw = z0
            return logw


class DurationPredictor(nn.Module):
    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        x = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask


class TextEncoder(nn.Module):
    def __init__(
        self,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        latent_channels=192,
        version="v2",
    ):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.latent_channels = latent_channels
        self.version = version

        self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)

        self.encoder_ssl = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers // 2,
            kernel_size,
            p_dropout,
        )

        self.encoder_text = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )

        if self.version == "v1":
            symbols = symbols_v1.symbols
        else:
            symbols = symbols_v2.symbols
        self.text_embedding = nn.Embedding(len(symbols), hidden_channels)

        self.mrte = MRTE()

        self.encoder2 = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers // 2,
            kernel_size,
            p_dropout,
        )

        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None):
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
            y.dtype
        )

        y = self.ssl_proj(y * y_mask) * y_mask

        y = self.encoder_ssl(y * y_mask, y_mask)

        text_mask = torch.unsqueeze(
            commons.sequence_mask(text_lengths, text.size(1)), 1
        ).to(y.dtype)
        if test == 1:
            text[:, :] = 0
        text = self.text_embedding(text).transpose(1, 2)
        text = self.encoder_text(text * text_mask, text_mask)
        y = self.mrte(y, y_mask, text, text_mask, ge)
        y = self.encoder2(y * y_mask, y_mask)
        if speed != 1:
            y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
            y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
        stats = self.proj(y) * y_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return y, m, logs, y_mask

    def extract_latent(self, x):
        x = self.ssl_proj(x)
        quantized, codes, commit_loss, quantized_list = self.quantizer(x)
        return codes.transpose(0, 1)

    def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
        quantized = self.quantizer.decode(codes)

        y = self.vq_proj(quantized) * y_mask
        y = self.encoder_ssl(y * y_mask, y_mask)

        y = self.mrte(y, y_mask, refer, refer_mask, ge)

        y = self.encoder2(y * y_mask, y_mask)

        stats = self.proj(y) * y_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return y, m, logs, y_mask, quantized


class ResidualCouplingBlock(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(
                modules.ResidualCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    n_layers,
                    gin_channels=gin_channels,
                    mean_only=True,
                )
            )
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x


class PosteriorEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        if g is not None:
            g = g.detach()
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask


class WNEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.norm = modules.LayerNorm(out_channels)

    def forward(self, x, x_lengths, g=None):
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        out = self.proj(x) * x_mask
        out = self.norm(out)
        return out


class Generator(torch.nn.Module):
    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(resblock_kernel_sizes, resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()

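# Shape sketch for the HiFi-GAN-style generator above. The argument values are
# illustrative, borrowed from typical s2 configs rather than defined in this file:
# the total upsampling factor is prod(upsample_rates) output samples per input frame.
#
#     gen = Generator(
#         192, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
#         [10, 8, 2, 2, 2], 512, [16, 16, 8, 2, 2],
#     )
#     wav = gen(torch.randn(1, 192, 50))  # -> (1, 1, 50 * 640)
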
class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(
                    Conv2d(
                        1,
                        32,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        32,
                        128,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        128,
                        512,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        512,
                        1024,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        1024,
                        1024,
                        (kernel_size, 1),
                        1,
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2, 3, 5, 7, 11]

        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs = discs + [
            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
        ]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

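# Usage sketch: the multi-period discriminator runs one waveform-scale head
# (DiscriminatorS) plus one DiscriminatorP per period in {2, 3, 5, 7, 11} and returns
# per-head logits and feature maps, which feed discriminator_loss / generator_loss /
# feature_loss in module/losses.py. The tensors below are illustrative:
#
#     mpd = MultiPeriodDiscriminator()
#     y_d_rs, y_d_gs, fmap_rs, fmap_gs = mpd(torch.randn(1, 1, 8192), torch.randn(1, 1, 8192))
#     loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
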
class ReferenceEncoder(nn.Module):
    """
    inputs --- [N, Ty/r, n_mels*r] mels
    outputs --- [N, ref_enc_gru_size]
    """

    def __init__(self, spec_channels, gin_channels=0):
        super().__init__()
        self.spec_channels = spec_channels
        ref_enc_filters = [32, 32, 64, 64, 128, 128]
        K = len(ref_enc_filters)
        filters = [1] + ref_enc_filters
        convs = [
            weight_norm(
                nn.Conv2d(
                    in_channels=filters[i],
                    out_channels=filters[i + 1],
                    kernel_size=(3, 3),
                    stride=(2, 2),
                    padding=(1, 1),
                )
            )
            for i in range(K)
        ]
        self.convs = nn.ModuleList(convs)
        # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])

        out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1] * out_channels,
            hidden_size=256 // 2,
            batch_first=True,
        )
        self.proj = nn.Linear(128, gin_channels)

    def forward(self, inputs):
        N = inputs.size(0)
        out = inputs.view(N, 1, -1, self.spec_channels)  # [N, 1, Ty, n_freqs]
        for conv in self.convs:
            out = conv(out)
            # out = wn(out)
            out = F.relu(out)  # [N, 128, Ty//2^K, n_mels//2^K]

        out = out.transpose(1, 2)  # [N, Ty//2^K, 128, n_mels//2^K]
        T = out.size(1)
        N = out.size(0)
        out = out.contiguous().view(N, T, -1)  # [N, Ty//2^K, 128*n_mels//2^K]

        self.gru.flatten_parameters()
        memory, out = self.gru(out)  # out --- [1, N, 128]

        return self.proj(out.squeeze(0)).unsqueeze(-1)

    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
        for i in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L


class Quantizer_module(torch.nn.Module):
    def __init__(self, n_e, e_dim):
        super(Quantizer_module, self).__init__()
        self.embedding = nn.Embedding(n_e, e_dim)
        self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)

    def forward(self, x):
        # squared distances between each input vector and every codebook entry
        d = (
            torch.sum(x**2, 1, keepdim=True)
            + torch.sum(self.embedding.weight**2, 1)
            - 2 * torch.matmul(x, self.embedding.weight.T)
        )
        min_indicies = torch.argmin(d, 1)
        z_q = self.embedding(min_indicies)
        return z_q, min_indicies


class Quantizer(torch.nn.Module):
    def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160):
        super(Quantizer, self).__init__()
        assert embed_dim % n_code_groups == 0
        self.quantizer_modules = nn.ModuleList(
            [
                Quantizer_module(n_codes, embed_dim // n_code_groups)
                for _ in range(n_code_groups)
            ]
        )
        self.n_code_groups = n_code_groups
        self.embed_dim = embed_dim

    def forward(self, xin):
        # B, C, T
        B, C, T = xin.shape
        xin = xin.transpose(1, 2)
        x = xin.reshape(-1, self.embed_dim)
        x = torch.split(x, self.embed_dim // self.n_code_groups, dim=-1)
        min_indicies = []
        z_q = []
        for _x, m in zip(x, self.quantizer_modules):
            _z_q, _min_indicies = m(_x)
            z_q.append(_z_q)
            min_indicies.append(_min_indicies)  # B * T,
        z_q = torch.cat(z_q, -1).reshape(xin.shape)
        # VQ commitment + codebook loss, with a straight-through estimator below
        loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
            (z_q - xin.detach()) ** 2
        )
        z_q = xin + (z_q - xin).detach()
        z_q = z_q.transpose(1, 2)
        codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
        return z_q, loss, codes.transpose(1, 2)

    def embed(self, x):
        # idx: N, 4, T
        x = x.transpose(1, 2)
        x = torch.split(x, 1, 2)
        ret = []
        for q, embed in zip(x, self.quantizer_modules):
            q = embed.embedding(q.squeeze(-1))
            ret.append(q)
        ret = torch.cat(ret, -1)
        return ret.transpose(1, 2)  # N, C, T

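# Round-trip sketch for the grouped quantizer above: with the defaults, each 512-dim
# frame is split into 4 groups of 128, each group is matched against its own 160-entry
# codebook, and `embed` maps code indices back to vectors:
#
#     q = Quantizer(embed_dim=512, n_code_groups=4, n_codes=160)
#     z_q, vq_loss, codes = q(torch.randn(2, 512, 50))  # z_q: (2, 512, 50), codes: (2, 4, 50)
#     recon = q.embed(codes)                            # (2, 512, 50)
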
class CodePredictor(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        n_q=8,
        dims=1024,
        ssl_dim=768,
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
        self.ref_enc = modules.MelStyleEncoder(
            ssl_dim, style_vector_dim=hidden_channels
        )

        self.encoder = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )

        self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
        self.n_q = n_q
        self.dims = dims

    def forward(self, x, x_mask, refer, codes, infer=False):
        x = x.detach()
        x = self.vq_proj(x * x_mask) * x_mask
        g = self.ref_enc(refer, x_mask)
        x = x + g
        x = self.encoder(x * x_mask, x_mask)
        x = self.out_proj(x * x_mask) * x_mask
        logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
            2, 3
        )
        target = codes[1:].transpose(0, 1)
        if not infer:
            logits = logits.reshape(-1, self.dims)
            target = target.reshape(-1)
            loss = torch.nn.functional.cross_entropy(logits, target)
            return loss
        else:
            _, top10_preds = torch.topk(logits, 10, dim=-1)
            correct_top10 = torch.any(top10_preds == target.unsqueeze(-1), dim=-1)
            top10_acc = 100 * torch.mean(correct_top10.float()).detach().cpu().item()

            print("Top-10 Accuracy:", top10_acc, "%")

            pred_codes = torch.argmax(logits, dim=-1)
            acc = 100 * torch.mean((pred_codes == target).float()).detach().cpu().item()
            print("Top-1 Accuracy:", acc, "%")

            return pred_codes.transpose(0, 1)


class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training
    """

    def __init__(
        self,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=0,
        gin_channels=0,
        use_sdp=True,
        semantic_frame_rate=None,
        freeze_quantizer=None,
        version="v2",
        **kwargs
    ):
        super().__init__()
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels
        self.version = version

        self.use_sdp = use_sdp
        self.enc_p = TextEncoder(
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            version=version,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
        )

        # self.version=os.environ.get("version","v1")
        if self.version == "v1":
            self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
        else:
            self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)

        ssl_dim = 768
        assert semantic_frame_rate in ["25hz", "50hz"]
        self.semantic_frame_rate = semantic_frame_rate
        if semantic_frame_rate == "25hz":
            self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
        else:
            self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)

        self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
        self.freeze_quantizer = freeze_quantizer
        self.sv_emb = nn.Linear(20480, gin_channels)
        self.ge_to512 = nn.Linear(gin_channels, 512)
        self.prelu = nn.PReLU(num_parameters=gin_channels)

    # sv_emb added to the signature: the original header omitted it even though the
    # speaker-verification embedding path in the body consumes it
    def forward(self, ssl, y, y_lengths, text, text_lengths, sv_emb=None):
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
            y.dtype
        )
        if self.version == "v1":
            ge = self.ref_enc(y * y_mask, y_mask)
        else:
            ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
        sv_emb = self.sv_emb(sv_emb)  # B*20480 -> B*512
        ge += sv_emb.unsqueeze(-1)
        ge = self.prelu(ge)
        ge512 = self.ge_to512(ge.transpose(2, 1)).transpose(2, 1)
        with autocast(enabled=False):
            maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
            with maybe_no_grad:
                if self.freeze_quantizer:
                    self.ssl_proj.eval()
                    self.quantizer.eval()
                ssl = self.ssl_proj(ssl)
                quantized, codes, commit_loss, quantized_list = self.quantizer(
                    ssl, layers=[0]
                )

        if self.semantic_frame_rate == "25hz":
            quantized = F.interpolate(
                quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
            )

        x, m_p, logs_p, y_mask = self.enc_p(
            quantized, y_lengths, text, text_lengths, ge512
        )
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
        z_p = self.flow(z, y_mask, g=ge)

        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=ge)
        return (
            o,
            commit_loss,
            ids_slice,
            y_mask,
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
            quantized,
        )

    def infer(self, ssl, y, y_lengths, text, text_lengths, test=None, noise_scale=0.5):
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
            y.dtype
        )
        if self.version == "v1":
            ge = self.ref_enc(y * y_mask, y_mask)
        else:
            ge = self.ref_enc(y[:, :704] * y_mask, y_mask)

        ssl = self.ssl_proj(ssl)
        quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
        if self.semantic_frame_rate == "25hz":
            quantized = F.interpolate(
                quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
            )

        x, m_p, logs_p, y_mask = self.enc_p(
            quantized, y_lengths, text, text_lengths, ge, test=test
        )
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

        z = self.flow(z_p, y_mask, g=ge, reverse=True)

        o = self.dec((z * y_mask)[:, :, :], g=ge)
        return o, y_mask, (z, z_p, m_p, logs_p)

    @torch.no_grad()
    def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
        def get_ge(refer, sv_emb):
            ge = None
            if refer is not None:
                refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
                refer_mask = torch.unsqueeze(
                    commons.sequence_mask(refer_lengths, refer.size(2)), 1
                ).to(refer.dtype)
                if self.version == "v1":
                    ge = self.ref_enc(refer * refer_mask, refer_mask)
                else:
|
| 1005 |
+
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
| 1006 |
+
sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
|
| 1007 |
+
ge += sv_emb.unsqueeze(-1)
|
| 1008 |
+
ge = self.prelu(ge)
|
| 1009 |
+
return ge
|
| 1010 |
+
if(type(refer)==list):
|
| 1011 |
+
ges=[]
|
| 1012 |
+
for idx,_refer in enumerate(refer):
|
| 1013 |
+
ge=get_ge(_refer,sv_emb[idx])
|
| 1014 |
+
ges.append(ge)
|
| 1015 |
+
ge=torch.stack(ges,0).mean(0)
|
| 1016 |
+
else:
|
| 1017 |
+
ge = get_ge(refer, sv_emb)
|
| 1018 |
+
|
| 1019 |
+
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
|
| 1020 |
+
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
|
| 1021 |
+
|
| 1022 |
+
quantized = self.quantizer.decode(codes)
|
| 1023 |
+
if self.semantic_frame_rate == "25hz":
|
| 1024 |
+
quantized = F.interpolate(
|
| 1025 |
+
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
|
| 1026 |
+
)
|
| 1027 |
+
x, m_p, logs_p, y_mask = self.enc_p(
|
| 1028 |
+
quantized, y_lengths, text, text_lengths, self.ge_to512(ge.transpose(2,1)).transpose(2,1),speed
|
| 1029 |
+
)
|
| 1030 |
+
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
| 1031 |
+
|
| 1032 |
+
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
| 1033 |
+
|
| 1034 |
+
o = self.dec((z * y_mask)[:, :, :], g=ge)
|
| 1035 |
+
return o
|
| 1036 |
+
|
| 1037 |
+
def extract_latent(self, x):
|
| 1038 |
+
ssl = self.ssl_proj(x)
|
| 1039 |
+
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
|
| 1040 |
+
return codes.transpose(0, 1)
|
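Note for review: the decode() path above is the one driven at synthesis time (see inference_webui.py). A minimal, hypothetical driver is sketched below; the constructor values mirror a typical configs/s2.json and the tensor shapes are illustrative assumptions, not values fixed by this commit.

# Sketch only, not part of the diff: drive SynthesizerTrn.decode with dummy inputs.
import torch
from module.models import SynthesizerTrn

model = SynthesizerTrn(
    spec_channels=1025, segment_size=32, inter_channels=192, hidden_channels=192,
    filter_channels=768, n_heads=2, n_layers=6, kernel_size=3, p_dropout=0.1,
    resblock="1", resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[10, 8, 2, 2, 2], upsample_initial_channel=512,
    upsample_kernel_sizes=[16, 16, 8, 2, 2], gin_channels=512,
    semantic_frame_rate="25hz", version="v2",
).eval()

codes = torch.zeros(1, 1, 100, dtype=torch.long)  # semantic token indices
text = torch.zeros(1, 50, dtype=torch.long)       # phoneme ids
refer = torch.randn(1, 1025, 200)                 # reference spectrogram (B, n_freq, T)
sv_emb = torch.randn(1, 20480)                    # speaker-verification embedding (v2Pro path)
with torch.no_grad():
    audio = model.decode(codes, text, refer, noise_scale=0.5, sv_emb=sv_emb)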
module/models_onnx.py
ADDED
@@ -0,0 +1,918 @@
import copy
import math
import torch
from torch import nn
from torch.nn import functional as F

from module import commons
from module import modules
from module import attentions_onnx as attentions

from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer
from text import symbols
from torch.cuda.amp import autocast


class StochasticDurationPredictor(nn.Module):
    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
        super().__init__()
        filter_channels = in_channels  # this line needs to be removed in a future version
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.log_flow = modules.Log()
        self.flows = nn.ModuleList()
        self.flows.append(modules.ElementwiseAffine(2))
        for i in range(n_flows):
            self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            self.flows.append(modules.Flip())

        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
        self.post_flows = nn.ModuleList()
        self.post_flows.append(modules.ElementwiseAffine(2))
        for i in range(4):
            self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            self.post_flows.append(modules.Flip())

        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        x = torch.detach(x)
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if not reverse:
            flows = self.flows
            assert w is not None

            logdet_tot_q = 0
            h_w = self.post_pre(w)
            h_w = self.post_convs(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask
            e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
            z_q = e_q
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
                logdet_tot_q += logdet_q
            z_u, z1 = torch.split(z_q, [1, 1], 1)
            u = torch.sigmoid(z_u) * x_mask
            z0 = (w - u) * x_mask
            logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
            logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q

            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
            logdet_tot += logdet
            z = torch.cat([z0, z1], 1)
            for flow in flows:
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
            nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
            return nll + logq  # [b]
        else:
            flows = list(reversed(self.flows))
            flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
            z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            z0, z1 = torch.split(z, [1, 1], 1)
            logw = z0
            return logw


class DurationPredictor(nn.Module):
    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        x = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask


class TextEncoder(nn.Module):
    def __init__(
        self,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        latent_channels=192,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.latent_channels = latent_channels

        self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)

        self.encoder_ssl = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers // 2, kernel_size, p_dropout
        )

        self.encoder_text = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.text_embedding = nn.Embedding(len(symbols), hidden_channels)

        self.mrte = MRTE()

        self.encoder2 = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers // 2, kernel_size, p_dropout
        )

        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, y, text, ge):
        # For ONNX export, sequence lengths are not dynamic inputs here, so the
        # masks are simply all-ones tensors of the right shape.
        y_mask = torch.ones_like(y[:1, :1, :])

        y = self.ssl_proj(y * y_mask) * y_mask
        y = self.encoder_ssl(y * y_mask, y_mask)

        text_mask = torch.ones_like(text).to(y.dtype).unsqueeze(0)

        text = self.text_embedding(text).transpose(1, 2)
        text = self.encoder_text(text * text_mask, text_mask)
        y = self.mrte(y, y_mask, text, text_mask, ge)

        y = self.encoder2(y * y_mask, y_mask)

        stats = self.proj(y) * y_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return y, m, logs, y_mask

    # Note: extract_latent/decode_latent reference self.quantizer / self.vq_proj,
    # which are not defined on TextEncoder; they appear to be leftover dead code.
    def extract_latent(self, x):
        x = self.ssl_proj(x)
        quantized, codes, commit_loss, quantized_list = self.quantizer(x)
        return codes.transpose(0, 1)

    def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
        quantized = self.quantizer.decode(codes)

        y = self.vq_proj(quantized) * y_mask
        y = self.encoder_ssl(y * y_mask, y_mask)

        y = self.mrte(y, y_mask, refer, refer_mask, ge)

        y = self.encoder2(y * y_mask, y_mask)

        stats = self.proj(y) * y_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return y, m, logs, y_mask, quantized

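# --- Reviewer sketch (illustration only, not part of models_onnx.py) ---
# Smoke test for the ONNX-oriented TextEncoder above; every size below is an
# assumption (dims of a typical s2 config), and ge's 512 channels assume MRTE()'s defaults.
import torch

enc = TextEncoder(out_channels=192, hidden_channels=192, filter_channels=768,
                  n_heads=2, n_layers=6, kernel_size=3, p_dropout=0.1).eval()
y = torch.randn(1, 768, 120)                 # SSL features (B, 768, T)
text = torch.zeros(1, 40, dtype=torch.long)  # phoneme ids (B, T_text)
ge = torch.randn(1, 512, 1)                  # global style vector
with torch.no_grad():
    out, m, logs, y_mask = enc(y, text, ge)  # out/m/logs: (1, 192, 120)
# --- End sketch ---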
class ResidualCouplingBlock(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(
                modules.ResidualCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    n_layers,
                    gin_channels=gin_channels,
                    mean_only=True,
                )
            )
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x


class PosteriorEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        if g is not None:
            g = g.detach()
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask


class WNEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.norm = modules.LayerNorm(out_channels)

    def forward(self, x, x_lengths, g=None):
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        out = self.proj(x) * x_mask
        out = self.norm(out)
        return out


class Generator(torch.nn.Module):
    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = weight_norm if not use_spectral_norm else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if not use_spectral_norm else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2, 3, 5, 7, 11]

        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class ReferenceEncoder(nn.Module):
    """
    inputs --- [N, Ty/r, n_mels*r]  mels
    outputs --- [N, ref_enc_gru_size]
    """

    def __init__(self, spec_channels, gin_channels=0):
        super().__init__()
        self.spec_channels = spec_channels
        ref_enc_filters = [32, 32, 64, 64, 128, 128]
        K = len(ref_enc_filters)
        filters = [1] + ref_enc_filters
        convs = [
            weight_norm(
                nn.Conv2d(
                    in_channels=filters[i],
                    out_channels=filters[i + 1],
                    kernel_size=(3, 3),
                    stride=(2, 2),
                    padding=(1, 1),
                )
            )
            for i in range(K)
        ]
        self.convs = nn.ModuleList(convs)
        # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])

        out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1] * out_channels,
            hidden_size=256 // 2,
            batch_first=True,
        )
        self.proj = nn.Linear(128, gin_channels)

    def forward(self, inputs):
        N = inputs.size(0)
        out = inputs.view(N, 1, -1, self.spec_channels)  # [N, 1, Ty, n_freqs]
        for conv in self.convs:
            out = conv(out)
            # out = wn(out)
            out = F.relu(out)  # [N, 128, Ty//2^K, n_mels//2^K]

        out = out.transpose(1, 2)  # [N, Ty//2^K, 128, n_mels//2^K]
        T = out.size(1)
        N = out.size(0)
        out = out.contiguous().view(N, T, -1)  # [N, Ty//2^K, 128*n_mels//2^K]

        self.gru.flatten_parameters()
        memory, out = self.gru(out)  # out --- [1, N, 128]

        return self.proj(out.squeeze(0)).unsqueeze(-1)

    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
        for i in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L

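# --- Reviewer sketch (illustration only, not part of models_onnx.py) ---
# Sanity check of ReferenceEncoder.calculate_channels: each 3x3 conv with stride 2
# and padding 1 roughly halves the frequency axis, so the GRU input size is
# ref_enc_filters[-1] * calculate_channels(spec_channels, 3, 2, 1, 6).
def _calc(L, kernel_size=3, stride=2, pad=1, n_convs=6):
    for _ in range(n_convs):
        L = (L - kernel_size + 2 * pad) // stride + 1
    return L

print(_calc(704))        # 704 -> 352 -> 176 -> 88 -> 44 -> 22 -> 11 (704 mirrors the v2 ref_enc input)
print(128 * _calc(704))  # 1408: GRU input_size for a 704-bin spectrogram
# --- End sketch ---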
class Quantizer_module(torch.nn.Module):
    def __init__(self, n_e, e_dim):
        super(Quantizer_module, self).__init__()
        self.embedding = nn.Embedding(n_e, e_dim)
        self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)

    def forward(self, x):
        # squared L2 distance between each input vector and every codebook entry
        d = (
            torch.sum(x**2, 1, keepdim=True)
            + torch.sum(self.embedding.weight**2, 1)
            - 2 * torch.matmul(x, self.embedding.weight.T)
        )
        min_indicies = torch.argmin(d, 1)
        z_q = self.embedding(min_indicies)
        return z_q, min_indicies


class Quantizer(torch.nn.Module):
    def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160):
        super(Quantizer, self).__init__()
        assert embed_dim % n_code_groups == 0
        self.quantizer_modules = nn.ModuleList(
            [Quantizer_module(n_codes, embed_dim // n_code_groups) for _ in range(n_code_groups)]
        )
        self.n_code_groups = n_code_groups
        self.embed_dim = embed_dim

    def forward(self, xin):
        # B, C, T
        B, C, T = xin.shape
        xin = xin.transpose(1, 2)
        x = xin.reshape(-1, self.embed_dim)
        x = torch.split(x, self.embed_dim // self.n_code_groups, dim=-1)
        min_indicies = []
        z_q = []
        for _x, m in zip(x, self.quantizer_modules):
            _z_q, _min_indicies = m(_x)
            z_q.append(_z_q)
            min_indicies.append(_min_indicies)  # B * T,
        z_q = torch.cat(z_q, -1).reshape(xin.shape)
        # commitment + codebook loss, followed by the straight-through estimator
        loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
        z_q = xin + (z_q - xin).detach()
        z_q = z_q.transpose(1, 2)
        codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
        return z_q, loss, codes.transpose(1, 2)

    def embed(self, x):
        # idx: N, 4, T
        x = x.transpose(1, 2)
        x = torch.split(x, 1, 2)
        ret = []
        for q, embed in zip(x, self.quantizer_modules):
            q = embed.embedding(q.squeeze(-1))
            ret.append(q)
        ret = torch.cat(ret, -1)
        return ret.transpose(1, 2)  # N, C, T

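# --- Reviewer sketch (illustration only, not part of models_onnx.py) ---
# Round trip through the group-VQ Quantizer above: forward() returns a
# straight-through-quantized tensor plus codes, and embed() decodes the codes
# back to the same vectors.
import torch

q = Quantizer(embed_dim=512, n_code_groups=4, n_codes=160)
x = torch.randn(2, 512, 50)       # (B, C, T)
z_q, loss, codes = q(x)           # z_q: (2, 512, 50), codes: (2, 4, 50)
x_hat = q.embed(codes)            # (2, 512, 50)
print(torch.allclose(z_q, x_hat, atol=1e-5))  # True up to float rounding
# --- End sketch ---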
class CodePredictor(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        n_q=8,
        dims=1024,
        ssl_dim=768,
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
        self.ref_enc = modules.MelStyleEncoder(ssl_dim, style_vector_dim=hidden_channels)

        self.encoder = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )

        self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
        self.n_q = n_q
        self.dims = dims

    def forward(self, x, x_mask, refer, codes, infer=False):
        x = x.detach()
        x = self.vq_proj(x * x_mask) * x_mask
        g = self.ref_enc(refer, x_mask)
        x = x + g
        x = self.encoder(x * x_mask, x_mask)
        x = self.out_proj(x * x_mask) * x_mask
        logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3)
        target = codes[1:].transpose(0, 1)
        if not infer:
            logits = logits.reshape(-1, self.dims)
            target = target.reshape(-1)
            loss = torch.nn.functional.cross_entropy(logits, target)
            return loss
        else:
            _, top10_preds = torch.topk(logits, 10, dim=-1)
            correct_top10 = torch.any(top10_preds == target.unsqueeze(-1), dim=-1)
            # renamed from top3_acc: this is a top-10 accuracy, matching the print below
            top10_acc = 100 * torch.mean(correct_top10.float()).detach().cpu().item()

            print("Top-10 Accuracy:", top10_acc, "%")

            pred_codes = torch.argmax(logits, dim=-1)
            acc = 100 * torch.mean((pred_codes == target).float()).detach().cpu().item()
            print("Top-1 Accuracy:", acc, "%")

            return pred_codes.transpose(0, 1)


class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training
    """

    def __init__(
        self,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=0,
        gin_channels=0,
        use_sdp=True,
        semantic_frame_rate=None,
        freeze_quantizer=None,
        **kwargs
    ):
        super().__init__()
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels

        self.use_sdp = use_sdp
        self.enc_p = TextEncoder(
            inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels
        )
        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)

        self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)

        ssl_dim = 768
        self.ssl_dim = ssl_dim
        assert semantic_frame_rate in ["25hz", "50hz"]
        self.semantic_frame_rate = semantic_frame_rate
        if semantic_frame_rate == "25hz":
            self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
        else:
            self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)

        self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
        if freeze_quantizer:
            self.ssl_proj.requires_grad_(False)
            self.quantizer.requires_grad_(False)
            # self.enc_p.text_embedding.requires_grad_(False)
            # self.enc_p.encoder_text.requires_grad_(False)
            # self.enc_p.mrte.requires_grad_(False)

    # In this ONNX-oriented variant, forward() is the inference/export path:
    # it decodes semantic codes straight to audio.
    def forward(self, codes, text, refer):
        refer_mask = torch.ones_like(refer[:1, :1, :])
        ge = self.ref_enc(refer * refer_mask, refer_mask)

        quantized = self.quantizer.decode(codes)
        if self.semantic_frame_rate == "25hz":
            dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
            quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)

        x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge)

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p)

        z = self.flow(z_p, y_mask, g=ge, reverse=True)

        o = self.dec((z * y_mask)[:, :, :], g=ge)
        return o

    def extract_latent(self, x):
        ssl = self.ssl_proj(x)
        quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
        return codes.transpose(0, 1)
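Note for review: since forward() above is already the codes-to-audio inference path with fixed all-ones masks, exporting it with torch.onnx.export should be direct in principle. A hedged sketch follows; the constructor values and axis names are assumptions (a real export would read hparams from configs/s2.json), not part of this commit.

# Sketch only: exporting the ONNX-oriented SynthesizerTrn.
import torch
from module.models_onnx import SynthesizerTrn

model = SynthesizerTrn(
    spec_channels=1025, segment_size=32, inter_channels=192, hidden_channels=192,
    filter_channels=768, n_heads=2, n_layers=6, kernel_size=3, p_dropout=0.1,
    resblock="1", resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[10, 8, 2, 2, 2], upsample_initial_channel=512,
    upsample_kernel_sizes=[16, 16, 8, 2, 2], gin_channels=512,
    semantic_frame_rate="25hz",
).eval()

codes = torch.zeros(1, 1, 100, dtype=torch.long)  # RVQ indices
text = torch.zeros(1, 50, dtype=torch.long)       # phoneme ids
refer = torch.randn(1, 1025, 200)                 # reference spectrogram
torch.onnx.export(
    model, (codes, text, refer), "vits.onnx",
    input_names=["codes", "text", "refer"], output_names=["audio"],
    dynamic_axes={"codes": {2: "T_code"}, "text": {1: "T_text"}, "refer": {2: "T_ref"}},
    opset_version=17,
)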
module/modules.py
ADDED
@@ -0,0 +1,923 @@
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from torch.nn import Conv1d
from torch.nn.utils import weight_norm, remove_weight_norm

from module import commons
from module.commons import init_weights, get_padding
from module.transforms import piecewise_rational_quadratic_transform
import torch.distributions as D


LRELU_SLOPE = 0.1


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


class ConvReluNorm(nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        kernel_size,
        n_layers,
        p_dropout,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(
            nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
        )
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask


class DDSConv(nn.Module):
    """
    Dilated and Depth-Separable Convolution
    """

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for i in range(n_layers):
            dilation = kernel_size**i
            padding = (kernel_size * dilation - dilation) // 2
            self.convs_sep.append(
                nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
            )
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        if g is not None:
            x = x + g
        for i in range(self.n_layers):
            y = self.convs_sep[i](x * x_mask)
            y = self.norms_1[i](y)
            y = F.gelu(y)
            y = self.convs_1x1[i](y)
            y = self.norms_2[i](y)
            y = F.gelu(y)
            y = self.drop(y)
            x = x + y
        return x * x_mask


class WN(torch.nn.Module):
    def __init__(
        self,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
        p_dropout=0,
    ):
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        self.kernel_size = (kernel_size,)
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")

        for i in range(n_layers):
            dilation = dilation_rate**i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(
                hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
            )
            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # last one is not necessary
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is not None:
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

            acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                res_acts = res_skip_acts[:, : self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)

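# --- Reviewer sketch (illustration only, not part of modules.py) ---
# WN relies on commons.fused_add_tanh_sigmoid_multiply for its gated activation.
# The function below re-states the standard VITS implementation of that op (an
# assumption for illustration; the authoritative version lives in module/commons.py).
import torch

def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b                            # add conditioning
    t_act = torch.tanh(in_act[:, :n_channels_int, :])     # first half: filter
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])  # second half: gate
    return t_act * s_act

x_in = torch.randn(1, 2 * 192, 10)  # WN's in_layers emit 2 * hidden_channels
g_l = torch.zeros_like(x_in)
acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, torch.IntTensor([192]))
print(acts.shape)  # torch.Size([1, 192, 10])
# --- End sketch ---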
class ResBlock1(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                           padding=get_padding(kernel_size, dilation[0]))
                ),
                weight_norm(
                    Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                           padding=get_padding(kernel_size, dilation[1]))
                ),
                weight_norm(
                    Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                           padding=get_padding(kernel_size, dilation[2]))
                ),
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(channels, channels, kernel_size, 1, dilation=1,
                           padding=get_padding(kernel_size, 1))
                ),
                weight_norm(
                    Conv1d(channels, channels, kernel_size, 1, dilation=1,
                           padding=get_padding(kernel_size, 1))
                ),
                weight_norm(
                    Conv1d(channels, channels, kernel_size, 1, dilation=1,
                           padding=get_padding(kernel_size, 1))
                ),
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c2(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                           padding=get_padding(kernel_size, dilation[0]))
                ),
                weight_norm(
                    Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                           padding=get_padding(kernel_size, dilation[1]))
                ),
            ]
        )
        self.convs.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
|
| 358 |
+
return x
|
| 359 |
+
|
| 360 |
+
def remove_weight_norm(self):
|
| 361 |
+
for l in self.convs:
|
| 362 |
+
remove_weight_norm(l)
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
class Log(nn.Module):
|
| 366 |
+
def forward(self, x, x_mask, reverse=False, **kwargs):
|
| 367 |
+
if not reverse:
|
| 368 |
+
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
|
| 369 |
+
logdet = torch.sum(-y, [1, 2])
|
| 370 |
+
return y, logdet
|
| 371 |
+
else:
|
| 372 |
+
x = torch.exp(x) * x_mask
|
| 373 |
+
return x
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
class Flip(nn.Module):
|
| 377 |
+
def forward(self, x, *args, reverse=False, **kwargs):
|
| 378 |
+
x = torch.flip(x, [1])
|
| 379 |
+
if not reverse:
|
| 380 |
+
logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
|
| 381 |
+
return x, logdet
|
| 382 |
+
else:
|
| 383 |
+
return x
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
class ElementwiseAffine(nn.Module):
|
| 387 |
+
def __init__(self, channels):
|
| 388 |
+
super().__init__()
|
| 389 |
+
self.channels = channels
|
| 390 |
+
self.m = nn.Parameter(torch.zeros(channels, 1))
|
| 391 |
+
self.logs = nn.Parameter(torch.zeros(channels, 1))
|
| 392 |
+
|
| 393 |
+
def forward(self, x, x_mask, reverse=False, **kwargs):
|
| 394 |
+
if not reverse:
|
| 395 |
+
y = self.m + torch.exp(self.logs) * x
|
| 396 |
+
y = y * x_mask
|
| 397 |
+
logdet = torch.sum(self.logs * x_mask, [1, 2])
|
| 398 |
+
return y, logdet
|
| 399 |
+
else:
|
| 400 |
+
x = (x - self.m) * torch.exp(-self.logs) * x_mask
|
| 401 |
+
return x
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
class ResidualCouplingLayer(nn.Module):
|
| 405 |
+
def __init__(
|
| 406 |
+
self,
|
| 407 |
+
channels,
|
| 408 |
+
hidden_channels,
|
| 409 |
+
kernel_size,
|
| 410 |
+
dilation_rate,
|
| 411 |
+
n_layers,
|
| 412 |
+
p_dropout=0,
|
| 413 |
+
gin_channels=0,
|
| 414 |
+
mean_only=False,
|
| 415 |
+
):
|
| 416 |
+
assert channels % 2 == 0, "channels should be divisible by 2"
|
| 417 |
+
super().__init__()
|
| 418 |
+
self.channels = channels
|
| 419 |
+
self.hidden_channels = hidden_channels
|
| 420 |
+
self.kernel_size = kernel_size
|
| 421 |
+
self.dilation_rate = dilation_rate
|
| 422 |
+
self.n_layers = n_layers
|
| 423 |
+
self.half_channels = channels // 2
|
| 424 |
+
self.mean_only = mean_only
|
| 425 |
+
|
| 426 |
+
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
|
| 427 |
+
self.enc = WN(
|
| 428 |
+
hidden_channels,
|
| 429 |
+
kernel_size,
|
| 430 |
+
dilation_rate,
|
| 431 |
+
n_layers,
|
| 432 |
+
p_dropout=p_dropout,
|
| 433 |
+
gin_channels=gin_channels,
|
| 434 |
+
)
|
| 435 |
+
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
|
| 436 |
+
self.post.weight.data.zero_()
|
| 437 |
+
self.post.bias.data.zero_()
|
| 438 |
+
|
| 439 |
+
def forward(self, x, x_mask, g=None, reverse=False):
|
| 440 |
+
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
| 441 |
+
h = self.pre(x0) * x_mask
|
| 442 |
+
h = self.enc(h, x_mask, g=g)
|
| 443 |
+
stats = self.post(h) * x_mask
|
| 444 |
+
if not self.mean_only:
|
| 445 |
+
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
|
| 446 |
+
else:
|
| 447 |
+
m = stats
|
| 448 |
+
logs = torch.zeros_like(m)
|
| 449 |
+
|
| 450 |
+
if not reverse:
|
| 451 |
+
x1 = m + x1 * torch.exp(logs) * x_mask
|
| 452 |
+
x = torch.cat([x0, x1], 1)
|
| 453 |
+
logdet = torch.sum(logs, [1, 2])
|
| 454 |
+
return x, logdet
|
| 455 |
+
else:
|
| 456 |
+
x1 = (x1 - m) * torch.exp(-logs) * x_mask
|
| 457 |
+
x = torch.cat([x0, x1], 1)
|
| 458 |
+
return x
|
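
# Invertibility sketch for ResidualCouplingLayer (illustrative sizes). The reverse
# pass undoes the forward affine shift; note that self.post is zero-initialized,
# so at initialization the layer is exactly the identity:
#
#     layer = ResidualCouplingLayer(channels=192, hidden_channels=192, kernel_size=5,
#                                   dilation_rate=1, n_layers=4, mean_only=True)
#     x = torch.randn(2, 192, 50)
#     x_mask = torch.ones(2, 1, 50)
#     y, logdet = layer(x, x_mask)               # forward: x0 kept, x1 shifted by m(x0)
#     x_rec = layer(y, x_mask, reverse=True)     # inverse: subtract the same shift
#     assert torch.allclose(x, x_rec, atol=1e-4)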


class ConvFlow(nn.Module):
    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        n_layers,
        num_bins=10,
        tail_bound=5.0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.num_bins = num_bins
        self.tail_bound = tail_bound
        self.half_channels = in_channels // 2

        self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
        self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
        self.proj = nn.Conv1d(
            filter_channels, self.half_channels * (num_bins * 3 - 1), 1
        )
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0)
        h = self.convs(h, x_mask, g=g)
        h = self.proj(h) * x_mask

        b, c, t = x0.shape
        h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, c*?, t] -> [b, c, t, ?]

        unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
        unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
            self.filter_channels
        )
        unnormalized_derivatives = h[..., 2 * self.num_bins :]

        x1, logabsdet = piecewise_rational_quadratic_transform(
            x1,
            unnormalized_widths,
            unnormalized_heights,
            unnormalized_derivatives,
            inverse=reverse,
            tails="linear",
            tail_bound=self.tail_bound,
        )

        x = torch.cat([x0, x1], 1) * x_mask
        logdet = torch.sum(logabsdet * x_mask, [1, 2])
        if not reverse:
            return x, logdet
        else:
            return x


class LinearNorm(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        bias=True,
        spectral_norm=False,
    ):
        super(LinearNorm, self).__init__()
        self.fc = nn.Linear(in_channels, out_channels, bias)

        if spectral_norm:
            self.fc = nn.utils.spectral_norm(self.fc)

    def forward(self, input):
        out = self.fc(input)
        return out


class Mish(nn.Module):
    def __init__(self):
        super(Mish, self).__init__()

    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


class Conv1dGLU(nn.Module):
    """
    Conv1d + GLU (Gated Linear Unit) with residual connection.
    For GLU, refer to https://arxiv.org/abs/1612.08083.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dropout):
        super(Conv1dGLU, self).__init__()
        self.out_channels = out_channels
        self.conv1 = ConvNorm(in_channels, 2 * out_channels, kernel_size=kernel_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
        x = x1 * torch.sigmoid(x2)
        x = residual + self.dropout(x)
        return x


class ConvNorm(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        spectral_norm=False,
    ):
        super(ConvNorm, self).__init__()

        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

        if spectral_norm:
            self.conv = nn.utils.spectral_norm(self.conv)

    def forward(self, input):
        out = self.conv(input)
        return out


class MultiHeadAttention(nn.Module):
    """Multi-Head Attention module"""

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.0, spectral_norm=False):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)

        self.attention = ScaledDotProductAttention(
            temperature=np.power(d_model, 0.5), dropout=dropout
        )

        self.fc = nn.Linear(n_head * d_v, d_model)
        self.dropout = nn.Dropout(dropout)

        if spectral_norm:
            self.w_qs = nn.utils.spectral_norm(self.w_qs)
            self.w_ks = nn.utils.spectral_norm(self.w_ks)
            self.w_vs = nn.utils.spectral_norm(self.w_vs)
            self.fc = nn.utils.spectral_norm(self.fc)

    def forward(self, x, mask=None):
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_x, _ = x.size()

        residual = x

        q = self.w_qs(x).view(sz_b, len_x, n_head, d_k)
        k = self.w_ks(x).view(sz_b, len_x, n_head, d_k)
        v = self.w_vs(x).view(sz_b, len_x, n_head, d_v)
        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k)  # (n*b) x lq x dk
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k)  # (n*b) x lk x dk
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_v)  # (n*b) x lv x dv

        if mask is not None:
            slf_mask = mask.repeat(n_head, 1, 1)  # (n*b) x .. x ..
        else:
            slf_mask = None
        output, attn = self.attention(q, k, v, mask=slf_mask)

        output = output.view(n_head, sz_b, len_x, d_v)
        output = (
            output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1)
        )  # b x lq x (n*dv)

        output = self.fc(output)

        output = self.dropout(output) + residual
        return output, attn


class ScaledDotProductAttention(nn.Module):
    """Scaled Dot-Product Attention"""

    def __init__(self, temperature, dropout):
        super().__init__()
        self.temperature = temperature
        self.softmax = nn.Softmax(dim=2)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        attn = torch.bmm(q, k.transpose(1, 2))
        attn = attn / self.temperature

        if mask is not None:
            attn = attn.masked_fill(mask, -np.inf)

        attn = self.softmax(attn)
        p_attn = self.dropout(attn)

        output = torch.bmm(p_attn, v)
        return output, attn


class MelStyleEncoder(nn.Module):
    """MelStyleEncoder"""

    def __init__(
        self,
        n_mel_channels=80,
        style_hidden=128,
        style_vector_dim=256,
        style_kernel_size=5,
        style_head=2,
        dropout=0.1,
    ):
        super(MelStyleEncoder, self).__init__()
        self.in_dim = n_mel_channels
        self.hidden_dim = style_hidden
        self.out_dim = style_vector_dim
        self.kernel_size = style_kernel_size
        self.n_head = style_head
        self.dropout = dropout

        self.spectral = nn.Sequential(
            LinearNorm(self.in_dim, self.hidden_dim),
            Mish(),
            nn.Dropout(self.dropout),
            LinearNorm(self.hidden_dim, self.hidden_dim),
            Mish(),
            nn.Dropout(self.dropout),
        )

        self.temporal = nn.Sequential(
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
        )

        self.slf_attn = MultiHeadAttention(
            self.n_head,
            self.hidden_dim,
            self.hidden_dim // self.n_head,
            self.hidden_dim // self.n_head,
            self.dropout,
        )

        self.fc = LinearNorm(self.hidden_dim, self.out_dim)

    def temporal_avg_pool(self, x, mask=None):
        if mask is None:
            out = torch.mean(x, dim=1)
        else:
            len_ = (~mask).sum(dim=1).unsqueeze(1)
            x = x.masked_fill(mask.unsqueeze(-1), 0)
            x = x.sum(dim=1)
            out = torch.div(x, len_)
        return out

    def forward(self, x, mask=None):
        x = x.transpose(1, 2)
        if mask is not None:
            mask = (mask.int() == 0).squeeze(1)
        max_len = x.shape[1]
        slf_attn_mask = (
            mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
        )

        # spectral
        x = self.spectral(x)
        # temporal
        x = x.transpose(1, 2)
        x = self.temporal(x)
        x = x.transpose(1, 2)
        # self-attention
        if mask is not None:
            x = x.masked_fill(mask.unsqueeze(-1), 0)
        x, _ = self.slf_attn(x, mask=slf_attn_mask)
        # fc
        x = self.fc(x)
        # temporal average pooling
        w = self.temporal_avg_pool(x, mask=mask)

        return w.unsqueeze(-1)
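
# Usage sketch for MelStyleEncoder (illustrative shapes; defaults as defined above):
#
#     enc = MelStyleEncoder(n_mel_channels=80, style_vector_dim=256)
#     mel = torch.randn(2, 80, 120)      # (batch, n_mel_channels, frames)
#     style = enc(mel)                   # -> (2, 256, 1), one style vector per utterance
#
# An optional `mask` of shape (batch, 1, frames) with 1 = valid frame excludes
# padding from both the self-attention and the final temporal average pooling.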


class MelStyleEncoderVAE(nn.Module):
    def __init__(self, spec_channels, z_latent_dim, emb_dim):
        super().__init__()
        self.ref_encoder = MelStyleEncoder(spec_channels, style_vector_dim=emb_dim)
        self.fc1 = nn.Linear(emb_dim, z_latent_dim)
        self.fc2 = nn.Linear(emb_dim, z_latent_dim)
        self.fc3 = nn.Linear(z_latent_dim, emb_dim)
        self.z_latent_dim = z_latent_dim

    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def forward(self, inputs, mask=None):
        enc_out = self.ref_encoder(inputs.squeeze(-1), mask).squeeze(-1)
        mu = self.fc1(enc_out)
        logvar = self.fc2(enc_out)
        posterior = D.Normal(mu, torch.exp(logvar))
        kl_divergence = D.kl_divergence(
            posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar))
        )
        loss_kl = kl_divergence.mean()

        z = posterior.rsample()
        style_embed = self.fc3(z)

        return style_embed.unsqueeze(-1), loss_kl

    def infer(self, inputs=None, random_sample=False, manual_latent=None):
        if manual_latent is None:
            if random_sample:
                dev = next(self.parameters()).device
                posterior = D.Normal(
                    torch.zeros(1, self.z_latent_dim, device=dev),
                    torch.ones(1, self.z_latent_dim, device=dev),
                )
                z = posterior.rsample()
            else:
                enc_out = self.ref_encoder(inputs.transpose(1, 2))
                mu = self.fc1(enc_out)
                z = mu
        else:
            z = manual_latent
        style_embed = self.fc3(z)
        return style_embed.unsqueeze(-1), z


class ActNorm(nn.Module):
    def __init__(self, channels, ddi=False, **kwargs):
        super().__init__()
        self.channels = channels
        self.initialized = not ddi

        self.logs = nn.Parameter(torch.zeros(1, channels, 1))
        self.bias = nn.Parameter(torch.zeros(1, channels, 1))

    def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs):
        if x_mask is None:
            x_mask = torch.ones(x.size(0), 1, x.size(2)).to(
                device=x.device, dtype=x.dtype
            )
        x_len = torch.sum(x_mask, [1, 2])
        if not self.initialized:
            self.initialize(x, x_mask)
            self.initialized = True

        if reverse:
            z = (x - self.bias) * torch.exp(-self.logs) * x_mask
            logdet = None
            return z
        else:
            z = (self.bias + torch.exp(self.logs) * x) * x_mask
            logdet = torch.sum(self.logs) * x_len  # [b]
            return z, logdet

    def store_inverse(self):
        pass

    def set_ddi(self, ddi):
        self.initialized = not ddi

    def initialize(self, x, x_mask):
        with torch.no_grad():
            denom = torch.sum(x_mask, [0, 2])
            m = torch.sum(x * x_mask, [0, 2]) / denom
            m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
            v = m_sq - (m**2)
            logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))

            bias_init = (
                (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
            )
            logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)

            self.bias.data.copy_(bias_init)
            self.logs.data.copy_(logs_init)


class InvConvNear(nn.Module):
    def __init__(self, channels, n_split=4, no_jacobian=False, **kwargs):
        super().__init__()
        assert n_split % 2 == 0
        self.channels = channels
        self.n_split = n_split
        self.no_jacobian = no_jacobian

        w_init = torch.linalg.qr(
            torch.FloatTensor(self.n_split, self.n_split).normal_()
        )[0]
        if torch.det(w_init) < 0:
            w_init[:, 0] = -1 * w_init[:, 0]
        self.weight = nn.Parameter(w_init)

    def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs):
        b, c, t = x.size()
        assert c % self.n_split == 0
        if x_mask is None:
            x_mask = 1
            x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
        else:
            x_len = torch.sum(x_mask, [1, 2])

        x = x.view(b, 2, c // self.n_split, self.n_split // 2, t)
        x = (
            x.permute(0, 1, 3, 2, 4)
            .contiguous()
            .view(b, self.n_split, c // self.n_split, t)
        )

        if reverse:
            if hasattr(self, "weight_inv"):
                weight = self.weight_inv
            else:
                weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
            logdet = None
        else:
            weight = self.weight
            if self.no_jacobian:
                logdet = 0
            else:
                logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len  # [b]

        weight = weight.view(self.n_split, self.n_split, 1, 1)
        z = F.conv2d(x, weight)

        z = z.view(b, 2, self.n_split // 2, c // self.n_split, t)
        z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask
        if reverse:
            return z
        else:
            return z, logdet

    def store_inverse(self):
        self.weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
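
# Round-trip sketch for InvConvNear (illustrative sizes; channels must be divisible
# by n_split):
#
#     conv = InvConvNear(channels=192, n_split=4)
#     x = torch.randn(2, 192, 40)
#     z, logdet = conv(x)           # invertible 1x1 convolution over channel groups
#     conv.store_inverse()          # cache the inverse weight for repeated decoding
#     x_rec = conv(z, reverse=True)
#     assert torch.allclose(x, x_rec, atol=1e-4)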
module/mrte_model.py
ADDED
@@ -0,0 +1,192 @@
# This is a multi-reference timbre encoder (MRTE).

import torch
from torch import nn
from torch.nn.utils import remove_weight_norm, weight_norm
from module.attentions import MultiHeadAttention


class MRTE(nn.Module):
    def __init__(
        self,
        content_enc_channels=192,
        hidden_size=512,
        out_channels=192,
        kernel_size=5,
        n_heads=4,
        ge_layer=2,
    ):
        super(MRTE, self).__init__()
        self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
        self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
        self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
        self.c_post = nn.Conv1d(hidden_size, out_channels, 1)

    def forward(self, ssl_enc, ssl_mask, text, text_mask, ge, test=None):
        if ge is None:
            ge = 0
        attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)

        ssl_enc = self.c_pre(ssl_enc * ssl_mask)
        text_enc = self.text_pre(text * text_mask)
        if test is not None:
            if test == 0:
                x = (
                    self.cross_attention(
                        ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
                    )
                    + ssl_enc
                    + ge
                )
            elif test == 1:
                x = ssl_enc + ge
            elif test == 2:
                x = (
                    self.cross_attention(
                        ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask
                    )
                    + ge
                )
            else:
                raise ValueError("test should be 0, 1 or 2")
        else:
            x = (
                self.cross_attention(
                    ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
                )
                + ssl_enc
                + ge
            )
        x = self.c_post(x * ssl_mask)
        return x


class SpeakerEncoder(torch.nn.Module):
    def __init__(
        self,
        mel_n_channels=80,
        model_num_layers=2,
        model_hidden_size=256,
        model_embedding_size=256,
    ):
        super(SpeakerEncoder, self).__init__()
        self.lstm = nn.LSTM(
            mel_n_channels, model_hidden_size, model_num_layers, batch_first=True
        )
        self.linear = nn.Linear(model_hidden_size, model_embedding_size)
        self.relu = nn.ReLU()

    def forward(self, mels):
        self.lstm.flatten_parameters()
        _, (hidden, _) = self.lstm(mels.transpose(-1, -2))
        embeds_raw = self.relu(self.linear(hidden[-1]))
        return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
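
# Usage sketch for SpeakerEncoder (illustrative shapes): the last LSTM hidden state
# is projected, ReLU'd and L2-normalized into a fixed-size speaker embedding.
#
#     spk = SpeakerEncoder(mel_n_channels=80, model_embedding_size=256)
#     mels = torch.randn(4, 80, 200)     # (batch, mel_channels, frames)
#     emb = spk(mels)                    # -> (4, 256), rows have unit L2 norm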


class MELEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers)
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)

    def forward(self, x):
        # print(x.shape, x_lengths.shape)
        x = self.pre(x)
        x = self.enc(x)
        x = self.proj(x)
        return x


class WN(torch.nn.Module):
    def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers):
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()

        for i in range(n_layers):
            dilation = dilation_rate**i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = nn.Conv1d(
                hidden_channels,
                2 * hidden_channels,
                kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = weight_norm(in_layer)
            self.in_layers.append(in_layer)

            # last one is not necessary
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x):
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)

            acts = fused_add_tanh_sigmoid_multiply(x_in, n_channels_tensor)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                res_acts = res_skip_acts[:, : self.hidden_channels, :]
                x = x + res_acts
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts
        return output

    def remove_weight_norm(self):
        for l in self.in_layers:
            remove_weight_norm(l)
        for l in self.res_skip_layers:
            remove_weight_norm(l)


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input, n_channels):
    n_channels_int = n_channels[0]
    t_act = torch.tanh(input[:, :n_channels_int, :])
    s_act = torch.sigmoid(input[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


if __name__ == "__main__":
    content_enc = torch.randn(3, 192, 100)
    content_mask = torch.ones(3, 1, 100)
    # the reference stream feeds text_pre, so it must carry
    # content_enc_channels (192) channels
    ref_mel = torch.randn(3, 192, 30)
    ref_mask = torch.ones(3, 1, 30)
    model = MRTE()
    out = model(content_enc, content_mask, ref_mel, ref_mask, None)
    print(out.shape)
module/quantize.py
ADDED
@@ -0,0 +1,119 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Residual vector quantizer implementation."""

from dataclasses import dataclass, field
import math
import typing as tp

import torch
from torch import nn

from module.core_vq import ResidualVectorQuantization


@dataclass
class QuantizedResult:
    quantized: torch.Tensor
    codes: torch.Tensor
    bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
    penalty: tp.Optional[torch.Tensor] = None
    metrics: dict = field(default_factory=dict)


class ResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer.

    Args:
        dimension (int): Dimension of the codebooks.
        n_q (int): Number of residual vector quantizers used.
        bins (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any code
            whose exponential moving average cluster size is less than the specified
            threshold with a randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dimension: int = 256,
        n_q: int = 8,
        bins: int = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.n_q = n_q
        self.dimension = dimension
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            threshold_ema_dead_code=self.threshold_ema_dead_code,
        )

    def forward(
        self,
        x: torch.Tensor,
        n_q: tp.Optional[int] = None,
        layers: tp.Optional[list] = None,
    ):
        """Residual vector quantization on the given input tensor.

        Args:
            x (torch.Tensor): Input tensor.
            n_q (int): Number of quantizers used to quantize. Default: all quantizers.
            layers (list): Layers whose quantized output should also be returned. Default: None.
        Returns:
            The quantized (or approximately quantized) representation, the codes, the
            mean commitment loss, and the per-layer quantized outputs that were requested.
        """
        n_q = n_q if n_q else self.n_q
        if layers and max(layers) >= n_q:
            raise ValueError(
                f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must be less than B."
            )
        quantized, codes, commit_loss, quantized_list = self.vq(
            x, n_q=n_q, layers=layers
        )
        return quantized, codes, torch.mean(commit_loss), quantized_list

    def encode(
        self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
    ) -> torch.Tensor:
        """Encode a given input tensor with the specified sample rate at the given bandwidth.

        The RVQ encode method sets the appropriate number of quantizers to use
        and returns indices for each quantizer.

        Args:
            x (torch.Tensor): Input tensor.
            n_q (int): Number of quantizers used to quantize. Default: all quantizers.
            st (int): Layer index from which to start encoding. Default: 0.
        """
        n_q = n_q if n_q else self.n_q
        st = st or 0
        codes = self.vq.encode(x, n_q=n_q, st=st)
        return codes

    def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
        """Decode the given codes to the quantized representation.

        Args:
            codes (torch.Tensor): Input indices for each quantizer.
            st (int): Layer index from which to start decoding. Default: 0.
        """
        quantized = self.vq.decode(codes, st=st)
        return quantized
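
# Encode/decode round-trip sketch (illustrative hyperparameters). Each residual
# stage quantizes what the previous stages left unexplained, so decoding with more
# of the n_q layers gives a closer reconstruction:
#
#     rvq = ResidualVectorQuantizer(dimension=256, n_q=8, bins=1024)
#     x = torch.randn(2, 256, 50)        # (batch, dimension, frames)
#     codes = rvq.encode(x)              # one tensor of codebook indices per quantizer
#     x_hat = rvq.decode(codes)          # sum of the selected codebook vectors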
module/transforms.py
ADDED
@@ -0,0 +1,209 @@
import torch
from torch.nn import functional as F

import numpy as np


DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3


def piecewise_rational_quadratic_transform(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails=None,
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    if tails is None:
        spline_fn = rational_quadratic_spline
        spline_kwargs = {}
    else:
        spline_fn = unconstrained_rational_quadratic_spline
        spline_kwargs = {"tails": tails, "tail_bound": tail_bound}

    outputs, logabsdet = spline_fn(
        inputs=inputs,
        unnormalized_widths=unnormalized_widths,
        unnormalized_heights=unnormalized_heights,
        unnormalized_derivatives=unnormalized_derivatives,
        inverse=inverse,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
        **spline_kwargs
    )
    return outputs, logabsdet


def searchsorted(bin_locations, inputs, eps=1e-6):
    bin_locations[..., -1] += eps
    return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1


def unconstrained_rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails="linear",
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)

    if tails == "linear":
        unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
        constant = np.log(np.exp(1 - min_derivative) - 1)
        unnormalized_derivatives[..., 0] = constant
        unnormalized_derivatives[..., -1] = constant

        outputs[outside_interval_mask] = inputs[outside_interval_mask]
        logabsdet[outside_interval_mask] = 0
    else:
        raise RuntimeError("{} tails are not implemented.".format(tails))

    (
        outputs[inside_interval_mask],
        logabsdet[inside_interval_mask],
    ) = rational_quadratic_spline(
        inputs=inputs[inside_interval_mask],
        unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
        unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
        unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
        inverse=inverse,
        left=-tail_bound,
        right=tail_bound,
        bottom=-tail_bound,
        top=tail_bound,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )

    return outputs, logabsdet


def rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    left=0.0,
    right=1.0,
    bottom=0.0,
    top=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise ValueError("Input to a transform is not within its domain")

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError("Minimal bin width too large for the number of bins")
    if min_bin_height * num_bins > 1.0:
        raise ValueError("Minimal bin height too large for the number of bins")

    widths = F.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
    cumwidths = (right - left) * cumwidths + left
    cumwidths[..., 0] = left
    cumwidths[..., -1] = right
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]

    derivatives = min_derivative + F.softplus(unnormalized_derivatives)

    heights = F.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
    cumheights = (top - bottom) * cumheights + bottom
    cumheights[..., 0] = bottom
    cumheights[..., -1] = top
    heights = cumheights[..., 1:] - cumheights[..., :-1]

    if inverse:
        bin_idx = searchsorted(cumheights, inputs)[..., None]
    else:
        bin_idx = searchsorted(cumwidths, inputs)[..., None]

    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]

    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
    delta = heights / widths
    input_delta = delta.gather(-1, bin_idx)[..., 0]

    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]

    input_heights = heights.gather(-1, bin_idx)[..., 0]

    if inverse:
        a = (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        ) + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        )
        c = -input_delta * (inputs - input_cumheights)

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        root = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = root * input_bin_widths + input_cumwidths

        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta
        )
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * root.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - root).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, -logabsdet
    else:
        theta = (inputs - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)

        numerator = input_heights * (
            input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
        )
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta
        )
        outputs = input_cumheights + numerator / denominator

        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * theta.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - theta).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, logabsdet
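
# Forward/inverse round-trip sketch for the spline (illustrative shapes; with
# tails="linear", inputs outside [-tail_bound, tail_bound] pass through unchanged):
#
#     x = torch.rand(4, 7) * 2 - 1       # values inside the default tail_bound of 1.0
#     w = torch.randn(4, 7, 10)          # unnormalized widths, 10 bins
#     h = torch.randn(4, 7, 10)          # unnormalized heights
#     d = torch.randn(4, 7, 9)           # num_bins - 1 interior knot derivatives
#     y, ld = piecewise_rational_quadratic_transform(x, w, h, d, tails="linear")
#     x_rec, ld_inv = piecewise_rational_quadratic_transform(
#         y, w, h, d, inverse=True, tails="linear"
#     )
#     assert torch.allclose(x, x_rec, atol=1e-4)
#     assert torch.allclose(ld, -ld_inv, atol=1e-4)   # log-determinants cancel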
packages.txt
ADDED
@@ -0,0 +1 @@
ffmpeg
pre-requirements.txt
ADDED
@@ -0,0 +1,2 @@
torch==2.5.1
torchaudio
pretrained_models/chinese-hubert-base/config.json
ADDED
@@ -0,0 +1,72 @@
{
  "_name_or_path": "/data/docker/liujing04/gpt-vits/chinese-hubert-base",
  "activation_dropout": 0.1,
  "apply_spec_augment": true,
  "architectures": [
    "HubertModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 1,
  "classifier_proj_size": 256,
  "conv_bias": false,
  "conv_dim": [512, 512, 512, 512, 512, 512, 512],
  "conv_kernel": [10, 3, 3, 3, 3, 2, 2],
  "conv_stride": [5, 2, 2, 2, 2, 2, 2],
  "ctc_loss_reduction": "sum",
  "ctc_zero_infinity": false,
  "do_stable_layer_norm": false,
  "eos_token_id": 2,
  "feat_extract_activation": "gelu",
  "feat_extract_norm": "group",
  "feat_proj_dropout": 0.0,
  "feat_proj_layer_norm": true,
  "final_dropout": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "layerdrop": 0.1,
  "mask_feature_length": 10,
  "mask_feature_min_masks": 0,
  "mask_feature_prob": 0.0,
  "mask_time_length": 10,
  "mask_time_min_masks": 2,
  "mask_time_prob": 0.05,
  "model_type": "hubert",
  "num_attention_heads": 12,
  "num_conv_pos_embedding_groups": 16,
  "num_conv_pos_embeddings": 128,
  "num_feat_extract_layers": 7,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "torch_dtype": "float16",
  "transformers_version": "4.30.2",
  "use_weighted_layer_sum": false,
  "vocab_size": 32
}
pretrained_models/chinese-hubert-base/preprocessor_config.json
ADDED
@@ -0,0 +1,9 @@
{
  "do_normalize": true,
  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
  "feature_size": 1,
  "padding_side": "right",
  "padding_value": 0,
  "return_attention_mask": false,
  "sampling_rate": 16000
}
pretrained_models/chinese-roberta-wwm-ext-large/config.json
ADDED
@@ -0,0 +1,34 @@
{
  "_name_or_path": "/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "directionality": "bidi",
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "output_past": true,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "torch_dtype": "float16",
  "transformers_version": "4.30.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 21128
}
pretrained_models/chinese-roberta-wwm-ext-large/tokenizer.json
ADDED
The diff for this file is too large to render.
process_ckpt.py
ADDED
@@ -0,0 +1,31 @@
import os
import shutil
import traceback
from collections import OrderedDict
from time import time as ttime

import torch

from tools.i18n.i18n import I18nAuto

i18n = I18nAuto()


def my_save(fea, path):
    # Work around torch.save failing on paths that contain Chinese characters:
    # save to an ASCII-named temp file first, then move it into place.
    dir_name = os.path.dirname(path)
    name = os.path.basename(path)
    tmp_path = "%s.pth" % (ttime())
    torch.save(fea, tmp_path)
    shutil.move(tmp_path, "%s/%s" % (dir_name, name))


def savee(ckpt, name, epoch, steps, hps):
    try:
        opt = OrderedDict()
        opt["weight"] = {}
        for key in ckpt.keys():
            if "enc_q" in key:
                continue
            opt["weight"][key] = ckpt[key].half()
        opt["config"] = hps
        opt["info"] = "%sepoch_%siteration" % (epoch, steps)
        # torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
        my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
        return "Success."
    except Exception:
        return traceback.format_exc()
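
# Usage sketch (hypothetical path): checkpoints written to directories with
# non-ASCII names go through the temp-file indirection in my_save.
#
#     state = {"step": 1000}
#     my_save(state, "logs/中文目录/model.pth")   # torch.save to a tmp file, then move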
requirements.txt
ADDED
@@ -0,0 +1,36 @@
numpy<2.0
scipy>=1.11.3
tensorboard==2.15.1
librosa==0.9.2
numba==0.56.4
pytorch-lightning>=2.4
ffmpeg-python==0.2.0
onnxruntime-gpu
tqdm==4.66.4
cn2an==0.5.22
pypinyin==0.50.0
pyopenjtalk==0.4.1
g2p_en==2.1.0
sentencepiece==0.1.99
transformers==4.43.0
chardet==3.0.4
PyYAML==6.0.1
psutil==5.9.7
jieba_fast==0.53
jieba==0.42.1
https://hf-mirror.com/lj1995/GPT-SoVITS-windows-package/resolve/main/langsegment-0.3.5-py3-none-any.whl?download=true
wordsegment==1.3.1
rotary_embedding_torch==0.6.4
spaces
pyjyutping==1.0.0
g2pk2==0.0.3
ko_pron==1.3
opencc==1.1.0
python_mecab_ko==1.3.7
pydantic==2.8.2
torchmetrics<=1.5
nltk==3.8.1
fast_langdetect==0.3.1
split_lang==2.1.0
ToJyutping==3.2.0
https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
sv.py
ADDED
@@ -0,0 +1,24 @@
import os
import sys

import torch

sys.path.append(f"{os.getcwd()}/eres2net")
sv_path = "pretrained_models/sv/pretrained_eres2netv2w24s4ep4.ckpt"
from ERes2NetV2 import ERes2NetV2
import kaldi as Kaldi


class SV:
    def __init__(self, device, is_half):
        pretrained_state = torch.load(sv_path, map_location="cpu", weights_only=False)
        embedding_model = ERes2NetV2(baseWidth=24, scale=4, expansion=4)
        embedding_model.load_state_dict(pretrained_state)
        embedding_model.eval()
        self.embedding_model = embedding_model
        if not is_half:
            self.embedding_model = self.embedding_model.to(device)
        else:
            self.embedding_model = self.embedding_model.half().to(device)
        self.is_half = is_half

    def compute_embedding3(self, wav):
        with torch.no_grad():
            if self.is_half:
                wav = wav.half()
            feat = torch.stack(
                [
                    Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0)
                    for wav0 in wav
                ]
            )
            sv_emb = self.embedding_model.forward3(feat)
        return sv_emb
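
# Usage sketch (hypothetical device setup; the fbank features assume 16 kHz audio):
#
#     sv = SV(device="cuda", is_half=False)
#     wav = torch.randn(1, 16000)        # (batch, samples) at 16 kHz
#     emb = sv.compute_embedding3(wav)   # speaker-verification embedding batch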
text/.gitignore
ADDED
@@ -0,0 +1,3 @@
G2PWModel
__pycache__
*.zip
text/LangSegmenter/__init__.py
ADDED
@@ -0,0 +1 @@
from .langsegmenter import LangSegmenter
text/LangSegmenter/langsegmenter.py
ADDED
@@ -0,0 +1,175 @@
import logging
import re

# silence jieba's logging
import jieba
jieba.setLogLevel(logging.CRITICAL)

# relocate fast_langdetect's model download cache into pretrained_models/
from pathlib import Path
import fast_langdetect
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(fast_langdetect.infer.LangDetectConfig(cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"))


from split_lang import LangSplitter


def full_en(text):
    pattern = r'^(?=.*[A-Za-z])[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$'
    return bool(re.match(pattern, text))


def full_cjk(text):
    # ranges taken from Wikipedia
    cjk_ranges = [
        (0x4E00, 0x9FFF),    # CJK Unified Ideographs
        (0x3400, 0x4DB5),    # CJK Extension A
        (0x20000, 0x2A6DD),  # CJK Extension B
        (0x2A700, 0x2B73F),  # CJK Extension C
        (0x2B740, 0x2B81F),  # CJK Extension D
        (0x2B820, 0x2CEAF),  # CJK Extension E
        (0x2CEB0, 0x2EBEF),  # CJK Extension F
        (0x30000, 0x3134A),  # CJK Extension G
        (0x31350, 0x323AF),  # CJK Extension H
        (0x2EBF0, 0x2EE5D),  # CJK Extension I
    ]

    pattern = r'[0-9、-〜。!?.!?… /]+$'

    cjk_text = ""
    for char in text:
        code_point = ord(char)
        in_cjk = any(start <= code_point <= end for start, end in cjk_ranges)
        if in_cjk or re.match(pattern, char):
            cjk_text += char
    return cjk_text


def split_jako(tag_lang, item):
    if tag_lang == "ja":
        pattern = r"([\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]+(?:[0-9、-〜。!?.!?… ]+[\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]*)*)"
    else:
        pattern = r"([\u1100-\u11FF\u3130-\u318F\uAC00-\uD7AF]+(?:[0-9、-〜。!?.!?… ]+[\u1100-\u11FF\u3130-\u318F\uAC00-\uD7AF]*)*)"

    lang_list: list[dict] = []
    tag = 0
    for match in re.finditer(pattern, item['text']):
        if match.start() > tag:
            lang_list.append({'lang': item['lang'], 'text': item['text'][tag:match.start()]})

        tag = match.end()
        lang_list.append({'lang': tag_lang, 'text': item['text'][match.start():match.end()]})

    if tag < len(item['text']):
        lang_list.append({'lang': item['lang'], 'text': item['text'][tag:len(item['text'])]})

    return lang_list


def merge_lang(lang_list, item):
    if lang_list and item['lang'] == lang_list[-1]['lang']:
        lang_list[-1]['text'] += item['text']
    else:
        lang_list.append(item)
    return lang_list


class LangSegmenter():
    # default language filter, based on the four languages GSV currently supports
    DEFAULT_LANG_MAP = {
        "zh": "zh",
        "yue": "zh",  # Cantonese
        "wuu": "zh",  # Wu Chinese
        "zh-cn": "zh",
        "zh-tw": "x",  # Traditional Chinese is mapped to "x" (unknown)
        "ko": "ko",
        "ja": "ja",
        "en": "en",
    }

    def getTexts(text):
        lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP)
        substr = lang_splitter.split_by_lang(text=text)

        lang_list: list[dict] = []

        for _, item in enumerate(substr):
            dict_item = {'lang': item.lang, 'text': item.text}

            # short English runs are easily misdetected as another language
            if full_en(dict_item['text']):
                dict_item['lang'] = 'en'
                lang_list = merge_lang(lang_list, dict_item)
                continue

            # handle Japanese embedded in non-Japanese text (kana only, no CJK ideographs)
            ja_list: list[dict] = []
            if dict_item['lang'] != 'ja':
                ja_list = split_jako('ja', dict_item)

            if not ja_list:
                ja_list.append(dict_item)

            # handle Korean embedded in non-Korean text (hangul only, no CJK ideographs)
            ko_list: list[dict] = []
            temp_list: list[dict] = []
            for _, ko_item in enumerate(ja_list):
                if ko_item["lang"] != 'ko':
                    ko_list = split_jako('ko', ko_item)

                if ko_list:
                    temp_list.extend(ko_list)
                else:
                    temp_list.append(ko_item)

            # no embedded Japanese/Korean found
            if len(temp_list) == 1:
                # for an unknown language, check whether it is CJK
                if dict_item['lang'] == 'x':
                    cjk_text = full_cjk(dict_item['text'])
                    if cjk_text:
                        dict_item = {'lang': 'zh', 'text': cjk_text}
                        lang_list = merge_lang(lang_list, dict_item)
                    else:
                        lang_list = merge_lang(lang_list, dict_item)
                    continue
                else:
                    lang_list = merge_lang(lang_list, dict_item)
                    continue

| 141 |
+
# 存在非日韩文夹日韩文
|
| 142 |
+
for _, temp_item in enumerate(temp_list):
|
| 143 |
+
# 未知语言检查是否为CJK
|
| 144 |
+
if temp_item['lang'] == 'x':
|
| 145 |
+
cjk_text = full_cjk(dict_item['text'])
|
| 146 |
+
if cjk_text:
|
| 147 |
+
dict_item = {'lang':'zh','text':cjk_text}
|
| 148 |
+
lang_list = merge_lang(lang_list,dict_item)
|
| 149 |
+
else:
|
| 150 |
+
lang_list = merge_lang(lang_list,dict_item)
|
| 151 |
+
else:
|
| 152 |
+
lang_list = merge_lang(lang_list,temp_item)
|
| 153 |
+
|
| 154 |
+
temp_list = lang_list
|
| 155 |
+
lang_list = []
|
| 156 |
+
for _, temp_item in enumerate(temp_list):
|
| 157 |
+
if temp_item['lang'] == 'x':
|
| 158 |
+
if lang_list:
|
| 159 |
+
temp_item['lang'] = lang_list[-1]['lang']
|
| 160 |
+
elif len(temp_list) > 1:
|
| 161 |
+
temp_item['lang'] = temp_list[1]['lang']
|
| 162 |
+
else:
|
| 163 |
+
temp_item['lang'] = 'zh'
|
| 164 |
+
|
| 165 |
+
lang_list = merge_lang(lang_list,temp_item)
|
| 166 |
+
|
| 167 |
+
return lang_list
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
if __name__ == "__main__":
|
| 171 |
+
text = "MyGO?,你也喜欢まいご吗?"
|
| 172 |
+
print(LangSegmenter.getTexts(text))
|
| 173 |
+
|
| 174 |
+
text = "ねえ、知ってる?最近、僕は天文学を勉強してるんだ。君の瞳が星空みたいにキラキラしてるからさ。"
|
| 175 |
+
print(LangSegmenter.getTexts(text))
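Note: the __main__ block above doubles as a demo. Downstream, each returned segment is a {'lang': ..., 'text': ...} dict; the sketch below shows how a caller might dispatch segments to per-language front ends. The g2p_* handlers are hypothetical placeholders, not part of this commit.

# --- consumer sketch (not part of the commit) ---
from text.LangSegmenter import LangSegmenter

def g2p_zh(t): return f"<zh:{t}>"  # placeholder
def g2p_ja(t): return f"<ja:{t}>"  # placeholder
def g2p_ko(t): return f"<ko:{t}>"  # placeholder
def g2p_en(t): return f"<en:{t}>"  # placeholder

handlers = {"zh": g2p_zh, "ja": g2p_ja, "ko": g2p_ko, "en": g2p_en}

for seg in LangSegmenter.getTexts("MyGO?,你也喜欢まいご吗?"):
    print(handlers[seg["lang"]](seg["text"]))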
|