Spaces: Running on T4
j committed · commit f98c92f · 1 parent: af25078

update to upstream chatterbox implementation, fixes token filtering/clamping
(This view is limited to 50 files because the commit contains too many changes; the raw diff has the full change set.)
- README.md +1 -0
- chatterbox/src/chatterbox/__init__.py +15 -0
- chatterbox/src/chatterbox/__pycache__/__init__.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/__pycache__/tts.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/__pycache__/utils.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/__pycache__/vc.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/__init__.py +0 -0
- chatterbox/src/chatterbox/models/s3gen/__pycache__/__init__.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/__pycache__/const.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/__pycache__/decoder.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/__pycache__/flow.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/__pycache__/flow_matching.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/__pycache__/hifigan.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/__pycache__/s3gen.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/__pycache__/xvector.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/configs.py +10 -0
- chatterbox/src/chatterbox/models/s3gen/flow.py +89 -41
- chatterbox/src/chatterbox/models/s3gen/flow_matching.py +1 -11
- chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/s3gen.py +2 -9
- chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/mask.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/mel.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/utils/mel.py +8 -4
- chatterbox/src/chatterbox/models/s3tokenizer/__pycache__/__init__.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3tokenizer/__pycache__/s3tokenizer.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/t3/__pycache__/__init__.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/t3/__pycache__/llama_configs.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/t3/__pycache__/t3.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/t3/inference/__pycache__/alignment_stream_analyzer.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py +56 -32
- chatterbox/src/chatterbox/models/t3/modules/__pycache__/cond_enc.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/t3/modules/__pycache__/learned_pos_emb.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/t3/modules/__pycache__/perceiver.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/t3/modules/__pycache__/t3_config.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/t3/modules/t3_config.py +30 -16
- chatterbox/src/chatterbox/models/t3/t3.py +56 -34
README.md
CHANGED

@@ -8,6 +8,7 @@ sdk_version: 5.29.0
 app_file: app.py
 pinned: false
 short_description: Expressive Zeroshot TTS
+python_version: 3.10
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

chatterbox/src/chatterbox/__init__.py
CHANGED

@@ -1,2 +1,17 @@
+try:
+    from importlib.metadata import version, PackageNotFoundError
+    try:
+        __version__ = version("chatterbox-tts")
+    except PackageNotFoundError:
+        __version__ = "0.1.4"  # Default fallback version
+except ImportError:
+    from importlib_metadata import version, PackageNotFoundError  # For Python <3.8
+    try:
+        __version__ = version("chatterbox-tts")
+    except PackageNotFoundError:
+        __version__ = "0.1.4"
+
+
 from .tts import ChatterboxTTS
 from .vc import ChatterboxVC
+from .mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES

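The new block resolves `__version__` from installed package metadata and falls back to a pinned string otherwise. A minimal smoke test of that behavior (assuming the src layout above is installed or importable as `chatterbox`):

import chatterbox

# Prints the distribution version when chatterbox-tts is installed;
# otherwise the except branch exposes the hard-coded fallback "0.1.4".
print(chatterbox.__version__)
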
chatterbox/src/chatterbox/__pycache__/__init__.cpython-311.pyc
DELETED — Binary file (275 Bytes)

chatterbox/src/chatterbox/__pycache__/tts.cpython-311.pyc
DELETED — Binary file (13.3 kB)

chatterbox/src/chatterbox/__pycache__/utils.cpython-311.pyc
DELETED — Binary file (858 Bytes)

chatterbox/src/chatterbox/__pycache__/vc.cpython-311.pyc
DELETED — Binary file (5.44 kB)

chatterbox/src/chatterbox/models/__init__.py
ADDED — File without changes

chatterbox/src/chatterbox/models/s3gen/__pycache__/__init__.cpython-311.pyc
DELETED — Binary file (294 Bytes)

chatterbox/src/chatterbox/models/s3gen/__pycache__/const.cpython-311.pyc
DELETED — Binary file (190 Bytes)

chatterbox/src/chatterbox/models/s3gen/__pycache__/decoder.cpython-311.pyc
DELETED — Binary file (16.9 kB)

chatterbox/src/chatterbox/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc
DELETED — Binary file (2.7 kB)

chatterbox/src/chatterbox/models/s3gen/__pycache__/flow.cpython-311.pyc
DELETED — Binary file (13.7 kB)

chatterbox/src/chatterbox/models/s3gen/__pycache__/flow_matching.cpython-311.pyc
DELETED — Binary file (13.3 kB)

chatterbox/src/chatterbox/models/s3gen/__pycache__/hifigan.cpython-311.pyc
DELETED — Binary file (26.3 kB)

chatterbox/src/chatterbox/models/s3gen/__pycache__/s3gen.cpython-311.pyc
DELETED — Binary file (13.7 kB)

chatterbox/src/chatterbox/models/s3gen/__pycache__/xvector.cpython-311.pyc
DELETED — Binary file (24 kB)

chatterbox/src/chatterbox/models/s3gen/configs.py
ADDED

@@ -0,0 +1,10 @@
+from ..utils import AttrDict
+
+CFM_PARAMS = AttrDict({
+    "sigma_min": 1e-06,
+    "solver": "euler",
+    "t_scheduler": "cosine",
+    "training_cfg_rate": 0.2,
+    "inference_cfg_rate": 0.7,
+    "reg_loss_type": "l1"
+})

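`CFM_PARAMS` centralizes the CFM hyperparameters that were previously duplicated as omegaconf objects in flow_matching.py and s3gen.py (both removed below). `AttrDict` comes from a shared utils module; a minimal sketch of the pattern, mirroring the inline definition this commit removes from t3.py further down:

class AttrDict(dict):
    # dict subclass whose keys double as attributes; self.__dict__ aliases
    # the dict storage, so d["solver"] and d.solver stay in sync
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self

params = AttrDict({"sigma_min": 1e-06, "solver": "euler"})
assert params.solver == params["solver"] == "euler"

This dual access is what lets the plain-dict `CFM_PARAMS` stand in for the removed `DictConfig`/`OmegaConf.create` objects without touching call sites.
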
chatterbox/src/chatterbox/models/s3gen/flow.py
CHANGED

@@ -14,32 +14,54 @@
 import logging
 import random
 from typing import Dict, Optional
+
+logger = logging.getLogger(__name__)
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
-from omegaconf import DictConfig
 from .utils.mask import make_pad_mask
+from .configs import CFM_PARAMS
 
 
 class MaskedDiffWithXvec(torch.nn.Module):
-    def __init__(
-        …  (old argument list elided in this view)
+    def __init__(
+        self,
+        input_size: int = 512,
+        output_size: int = 80,
+        spk_embed_dim: int = 192,
+        output_type: str = "mel",
+        vocab_size: int = 4096,
+        input_frame_rate: int = 50,
+        only_mask_loss: bool = True,
+        encoder: torch.nn.Module = None,
+        length_regulator: torch.nn.Module = None,
+        decoder: torch.nn.Module = None,
+        decoder_conf: Dict = {
+            'in_channels': 240,
+            'out_channel': 80,
+            'spk_emb_dim': 80,
+            'n_spks': 1,
+            'cfm_params': CFM_PARAMS,
+            'decoder_params': {
+                'channels': [256, 256],
+                'dropout': 0.0,
+                'attention_head_dim': 64,
+                'n_blocks': 4,
+                'num_mid_blocks': 12,
+                'num_heads': 8,
+                'act_fn': 'gelu',
+            }
+        },
+        mel_feat_conf: Dict = {
+            'n_fft': 1024,
+            'num_mels': 80,
+            'sampling_rate': 22050,
+            'hop_size': 256,
+            'win_size': 1024,
+            'fmin': 0,
+            'fmax': 8000
+        }
+    ):
         super().__init__()
         self.input_size = input_size
         self.output_size = output_size

@@ -74,7 +96,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
 
         # concat text and prompt_text
         mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
-        token = self.input_embedding(torch.clamp(token, min=0)) * mask
+        token = self.input_embedding(torch.clamp(token, min=0, max=self.input_embedding.num_embeddings-1)) * mask
 
         # text encode
         h, h_lengths = self.encoder(token, token_len)

@@ -124,7 +146,13 @@ class MaskedDiffWithXvec(torch.nn.Module):
         token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
         token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
         mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
-        token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+        # Check for out-of-bounds token IDs
+        vocab_size = self.input_embedding.num_embeddings
+        if token.max() >= vocab_size or token.min() < 0:
+            logging.warning(f"S3Gen: Token IDs out of bounds: min={token.min().item()}, max={token.max().item()}, vocab_size={vocab_size}")
+
+        token = self.input_embedding(torch.clamp(token, min=0, max=vocab_size-1)) * mask
 
         # text encode
         h, h_lengths = self.encoder(token, token_len)

@@ -153,25 +181,45 @@ class MaskedDiffWithXvec(torch.nn.Module):
 
 
 class CausalMaskedDiffWithXvec(torch.nn.Module):
-    def __init__(
-        …  (old argument list elided in this view)
+    def __init__(
+        self,
+        input_size: int = 512,
+        output_size: int = 80,
+        spk_embed_dim: int = 192,
+        output_type: str = "mel",
+        vocab_size: int = 6561,
+        input_frame_rate: int = 25,
+        only_mask_loss: bool = True,
+        token_mel_ratio: int = 2,
+        pre_lookahead_len: int = 3,
+        encoder: torch.nn.Module = None,
+        decoder: torch.nn.Module = None,
+        decoder_conf: Dict = {
+            'in_channels': 240,
+            'out_channel': 80,
+            'spk_emb_dim': 80,
+            'n_spks': 1,
+            'cfm_params': CFM_PARAMS,
+            'decoder_params': {
+                'channels': [256, 256],
+                'dropout': 0.0,
+                'attention_head_dim': 64,
+                'n_blocks': 4,
+                'num_mid_blocks': 12,
+                'num_heads': 8,
+                'act_fn': 'gelu',
+            }
+        },
+        mel_feat_conf: Dict = {
+            'n_fft': 1024,
+            'num_mels': 80,
+            'sampling_rate': 22050,
+            'hop_size': 256,
+            'win_size': 1024,
+            'fmin': 0,
+            'fmax': 8000
+        }
+    ):
         super().__init__()
         self.input_size = input_size
         self.output_size = output_size

@@ -215,7 +263,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         # concat text and prompt_text
         token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
         mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
-        token = self.input_embedding(torch.clamp(token, min=0)) * mask
+        token = self.input_embedding(torch.clamp(token, min=0, max=self.input_embedding.num_embeddings-1)) * mask
 
         # text encode
         h, h_lengths = self.encoder(token, token_len)

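The substantive fix in this file is the clamp upper bound: `torch.clamp(token, min=0)` guarded only against negative IDs, so any ID at or above `num_embeddings` still produced an out-of-range embedding lookup. A standalone illustration with toy sizes (not project code):

import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=8, embedding_dim=4)
token = torch.tensor([[1, 5, 9, -2]])  # 9 and -2 are both out of bounds

# Old behavior: clamp(min=0) lets 9 through, and emb(...) raises IndexError.
# New behavior: clamp into [0, num_embeddings - 1] before the lookup.
safe = torch.clamp(token, min=0, max=emb.num_embeddings - 1)
out = emb(safe)  # (1, 4, 4); out-of-range IDs map to the boundary rows
print(safe.tolist(), tuple(out.shape))
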
chatterbox/src/chatterbox/models/s3gen/flow_matching.py
CHANGED

@@ -15,17 +15,7 @@ import threading
 import torch
 import torch.nn.functional as F
 from .matcha.flow_matching import BASECFM
-from omegaconf import OmegaConf
-
-
-CFM_PARAMS = OmegaConf.create({
-    "sigma_min": 1e-06,
-    "solver": "euler",
-    "t_scheduler": "cosine",
-    "training_cfg_rate": 0.2,
-    "inference_cfg_rate": 0.7,
-    "reg_loss_type": "l1"
-})
+from .configs import CFM_PARAMS
 
 
 class ConditionalCFM(BASECFM):

chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc
DELETED — Binary file (21.3 kB)

chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc
DELETED — Binary file (6.46 kB)

chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc
DELETED — Binary file (14.7 kB)

chatterbox/src/chatterbox/models/s3gen/s3gen.py
CHANGED

@@ -19,7 +19,6 @@ import torch
 import torchaudio as ta
 from functools import lru_cache
 from typing import Optional
-from omegaconf import DictConfig
 
 from ..s3tokenizer import S3_SR, SPEECH_VOCAB_SIZE, S3Tokenizer
 from .const import S3GEN_SR

@@ -31,6 +30,7 @@ from .hifigan import HiFTGenerator
 from .transformer.upsample_encoder import UpsampleConformerEncoder
 from .flow_matching import CausalConditionalCFM
 from .decoder import ConditionalDecoder
+from .configs import CFM_PARAMS
 
 
 def drop_invalid_tokens(x):

@@ -85,14 +85,7 @@ class S3Token2Mel(torch.nn.Module):
             num_heads=8,
             act_fn='gelu',
         )
-        cfm_params = DictConfig({
-            "sigma_min": 1e-06,
-            "solver": 'euler',
-            "t_scheduler": 'cosine',
-            "training_cfg_rate": 0.2,
-            "inference_cfg_rate": 0.7,
-            "reg_loss_type": 'l1',
-        })
+        cfm_params = CFM_PARAMS
        decoder = CausalConditionalCFM(
             spk_emb_dim=80,
             cfm_params=cfm_params,

chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc
DELETED — Binary file (190 Bytes)

chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc
DELETED — Binary file (3.58 kB)

chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc
DELETED — Binary file (15.7 kB)

chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc
DELETED — Binary file (5.54 kB)

chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc
DELETED — Binary file (17.3 kB)

chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc
DELETED — Binary file (11.2 kB)

chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc
DELETED — Binary file (6.24 kB)

chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc
DELETED — Binary file (18.9 kB)

chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc
DELETED — Binary file (15.6 kB)

chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc
DELETED — Binary file (1.93 kB)

chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/mask.cpython-311.pyc
DELETED — Binary file (6.25 kB)

chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/mel.cpython-311.pyc
DELETED — Binary file (4.05 kB)

chatterbox/src/chatterbox/models/s3gen/utils/mel.py
CHANGED

@@ -1,8 +1,11 @@
 """mel-spectrogram extraction in Matcha-TTS"""
+import logging
 from librosa.filters import mel as librosa_mel_fn
 import torch
 import numpy as np
 
+logger = logging.getLogger(__name__)
+
 
 # NOTE: they decalred these global vars
 mel_basis = {}

@@ -42,10 +45,11 @@ def mel_spectrogram(y, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=48
     if len(y.shape) == 1:
         y = y[None, ]
 
-    if torch.min(y) < -1.0:
-        print("min value is ", torch.min(y))
-    if torch.max(y) > 1.0:
-        print("max value is ", torch.max(y))
+    # Debug: Check for audio clipping (values outside [-1.0, 1.0] range)
+    min_val = torch.min(y)
+    max_val = torch.max(y)
+    if min_val < -1.0 or max_val > 1.0:
+        logger.warning(f"Audio values outside normalized range: min={min_val.item():.4f}, max={max_val.item():.4f}")
 
     global mel_basis, hann_window  # pylint: disable=global-statement,global-variable-not-assigned
     if f"{str(fmax)}_{str(y.device)}" not in mel_basis:

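The clipping check moves from per-bound `print` calls to a single combined comparison on a module-level logger. The same logic in isolation, on a synthetic tensor:

import logging
import torch

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("mel")

y = torch.tensor([0.2, -1.3, 0.9, 1.1])  # synthetic audio with clipped samples
min_val, max_val = torch.min(y), torch.max(y)
if min_val < -1.0 or max_val > 1.0:
    logger.warning(f"Audio values outside normalized range: min={min_val.item():.4f}, max={max_val.item():.4f}")
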
chatterbox/src/chatterbox/models/s3tokenizer/__pycache__/__init__.cpython-311.pyc
DELETED — Binary file (1.37 kB)

chatterbox/src/chatterbox/models/s3tokenizer/__pycache__/s3tokenizer.cpython-311.pyc
DELETED — Binary file (7.94 kB)

chatterbox/src/chatterbox/models/t3/__pycache__/__init__.cpython-311.pyc
DELETED — Binary file (218 Bytes)

chatterbox/src/chatterbox/models/t3/__pycache__/llama_configs.cpython-311.pyc
DELETED — Binary file (1.34 kB)

chatterbox/src/chatterbox/models/t3/__pycache__/t3.cpython-311.pyc
DELETED — Binary file (15.8 kB)

chatterbox/src/chatterbox/models/t3/inference/__pycache__/alignment_stream_analyzer.cpython-311.pyc
DELETED — Binary file (7.08 kB)

chatterbox/src/chatterbox/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc
DELETED — Binary file (4.65 kB)

chatterbox/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py
CHANGED

@@ -10,6 +10,9 @@ from types import MethodType
 logger = logging.getLogger(__name__)
 
 
+LLAMA_ALIGNED_HEADS = [(12, 15), (13, 11), (9, 2)]
+
+
 @dataclass
 class AlignmentAnalysisResult:
     # was this frame detected as being part of a noisy beginning chunk with potential hallucinations?

@@ -49,21 +52,22 @@ class AlignmentStreamAnalyzer:
 
         self.complete = False
         self.completed_at = None
+
+        # Track generated tokens for repetition detection
+        self.generated_tokens = []
 
         # Using `output_attentions=True` is incompatible with optimized attention kernels, so
         # using it for all layers slows things down too much. We can apply it to just one layer
         # by intercepting the kwargs and adding a forward hook (credit: jrm)
-        self.…
-        …  (old single-buffer spy setup elided in this view)
+        self.last_aligned_attns = []
+        for i, (layer_idx, head_idx) in enumerate(LLAMA_ALIGNED_HEADS):
+            self.last_aligned_attns += [None]
+            self._add_attention_spy(tfmr, i, layer_idx, head_idx)
 
-    def _add_attention_spy(self, tfmr, …
+    def _add_attention_spy(self, tfmr, buffer_idx, layer_idx, head_idx):
         """
         Adds a forward hook to a specific attention layer to collect outputs.
-        Using `output_attentions=True` is incompatible with optimized attention kernels, so
-        using it for all layers slows things down too much.
-        (credit: jrm)
         """
-
         def attention_forward_hook(module, input, output):
             """
             See `LlamaAttention.forward`; the output is a 3-tuple: `attn_output, attn_weights, past_key_value`.

@@ -71,27 +75,23 @@ class AlignmentStreamAnalyzer:
             - When `output_attentions=True`, `LlamaSdpaAttention.forward` calls `LlamaAttention.forward`.
             - `attn_output` has shape [B, H, T0, T0] for the 0th entry, and [B, H, 1, T0+i] for the rest i-th.
             """
-            …  (old hook body and forward monkey-patching elided in this view)
-            # TODO: how to unpatch it?
-            target_layer.forward = MethodType(patched_forward, target_layer)
-
-    def step(self, logits):
+            if isinstance(output, tuple) and len(output) > 1 and output[1] is not None:
+                step_attention = output[1].cpu()  # (B, n_heads, T0, Ti)
+                self.last_aligned_attns[buffer_idx] = step_attention[0, head_idx]  # (T0, Ti)
+
+        target_layer = tfmr.layers[layer_idx].self_attn
+        # Register hook and store the handle
+        target_layer.register_forward_hook(attention_forward_hook)
+        if hasattr(tfmr, 'config') and hasattr(tfmr.config, 'output_attentions'):
+            self.original_output_attentions = tfmr.config.output_attentions
+            tfmr.config.output_attentions = True
+
+    def step(self, logits, next_token=None):
         """
         Emits an AlignmentAnalysisResult into the output queue, and potentially modifies the logits to force an EOS.
         """
         # extract approximate alignment matrix chunk (1 frame at a time after the first chunk)
-        aligned_attn = self.…
+        aligned_attn = torch.stack(self.last_aligned_attns).mean(dim=0)  # (N, N)
         i, j = self.text_tokens_slice
         if self.curr_frame_pos == 0:
             # first chunk has conditioning info, text tokens, and BOS token

@@ -133,22 +133,46 @@ class AlignmentStreamAnalyzer:
         last_text_token_duration = A[15:, -3:].sum()
 
         # Activations for the final token that last too long are likely hallucinations.
-        long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= …
+        long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 5)  # 200ms
 
         # If there are activations in previous tokens after generation has completed, assume this is a repetition error.
-        …
+        alignment_repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5)
+
+        # Track generated tokens for repetition detection
+        if next_token is not None:
+            # Convert tensor to scalar if needed
+            if isinstance(next_token, torch.Tensor):
+                token_id = next_token.item() if next_token.numel() == 1 else next_token.view(-1)[0].item()
+            else:
+                token_id = next_token
+            self.generated_tokens.append(token_id)
+
+            # Keep only last 8 tokens to prevent memory issues
+            if len(self.generated_tokens) > 8:
+                self.generated_tokens = self.generated_tokens[-8:]
+
+        # Check for excessive token repetition (3x same token in a row)
+        token_repetition = (
+            # self.complete and
+            len(self.generated_tokens) >= 3 and
+            len(set(self.generated_tokens[-2:])) == 1
+        )
+
+        if token_repetition:
+            repeated_token = self.generated_tokens[-1]
+            logger.warning(f"🚨 Detected 2x repetition of token {repeated_token}")
+
+        # Suppress EoS to prevent early termination
+        if cur_text_posn < S - 3 and S > 5:  # Only suppress if text is longer than 5 tokens
+            logits[..., self.eos_idx] = -2**15
 
         # If a bad ending is detected, force emit EOS by modifying logits
         # NOTE: this means logits may be inconsistent with latents!
-        if long_tail or …
-            logger.…
+        if long_tail or alignment_repetition or token_repetition:
+            logger.warning(f"forcing EOS token, {long_tail=}, {alignment_repetition=}, {token_repetition=}")
             # (±2**15 is safe for all dtypes >= 16bit)
             logits = -(2**15) * torch.ones_like(logits)
             logits[..., self.eos_idx] = 2**15
 
-        # Suppress EoS to prevent early termination
-        if cur_text_posn < S - 3:  # FIXME: arbitrary
-            logits[..., self.eos_idx] = -2**15
-
         self.curr_frame_pos += 1
         return logits

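Note that despite the "3x same token in a row" comment, the committed condition fires once the last two window entries match (consistent with the "2x" wording in the warning). A standalone sketch of the check as written, with hypothetical token IDs:

def detect_repetition(history, next_token, window=8):
    # mirrors AlignmentStreamAnalyzer.step: append, trim to the last
    # `window` tokens, then compare only the final two entries
    history.append(next_token)
    if len(history) > window:
        del history[:-window]
    return len(history) >= 3 and len(set(history[-2:])) == 1

history = []
hits = [detect_repetition(history, t) for t in [101, 101, 7, 7]]
print(hits)  # [False, False, False, True]
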
chatterbox/src/chatterbox/models/t3/modules/__pycache__/cond_enc.cpython-311.pyc
DELETED — Binary file (5.37 kB)

chatterbox/src/chatterbox/models/t3/modules/__pycache__/learned_pos_emb.cpython-311.pyc
DELETED — Binary file (2.54 kB)

chatterbox/src/chatterbox/models/t3/modules/__pycache__/perceiver.cpython-311.pyc
DELETED — Binary file (12.6 kB)

chatterbox/src/chatterbox/models/t3/modules/__pycache__/t3_config.cpython-311.pyc
DELETED — Binary file (1.27 kB)

chatterbox/src/chatterbox/models/t3/modules/t3_config.py
CHANGED

@@ -2,26 +2,40 @@ from ..llama_configs import LLAMA_CONFIGS
 
 
 class T3Config:
-    …  (old class-level attribute defaults elided in this view)
-    emotion_adv = True
+    def __init__(self, text_tokens_dict_size=704):
+        self.start_text_token = 255
+        self.stop_text_token = 0
+        self.text_tokens_dict_size = text_tokens_dict_size
+        self.max_text_tokens = 2048
+
+        self.start_speech_token = 6561
+        self.stop_speech_token = 6562
+        self.speech_tokens_dict_size = 8194
+        self.max_speech_tokens = 4096
+
+        self.llama_config_name = "Llama_520M"
+        self.input_pos_emb = "learned"
+        self.speech_cond_prompt_len = 150
+
+        self.encoder_type = "voice_encoder"
+        self.speaker_embed_size = 256
+        self.use_perceiver_resampler = True
+        self.emotion_adv = True
 
     @property
     def n_channels(self):
         return LLAMA_CONFIGS[self.llama_config_name]["hidden_size"]
+
+    @property
+    def is_multilingual(self):
+        return self.text_tokens_dict_size == 2454
+
+    @classmethod
+    def english_only(cls):
+        """Create configuration for English-only TTS model."""
+        return cls(text_tokens_dict_size=704)
+
+    @classmethod
+    def multilingual(cls):
+        """Create configuration for multilingual TTS model."""
+        return cls(text_tokens_dict_size=2454)

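The config is now instance-based, with the text vocabulary size selecting the variant and `is_multilingual` derived from it. A short usage sketch (the import path assumes this repo's src layout):

from chatterbox.models.t3.modules.t3_config import T3Config

en = T3Config.english_only()   # text_tokens_dict_size == 704
ml = T3Config.multilingual()   # text_tokens_dict_size == 2454
assert not en.is_multilingual
assert ml.is_multilingual
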
chatterbox/src/chatterbox/models/t3/t3.py
CHANGED

@@ -3,12 +3,14 @@
 import logging
 from typing import Union, Optional, List
 
+logger = logging.getLogger(__name__)
+
 from tqdm import tqdm
 import torch
 import torch.nn.functional as F
 from torch import nn, Tensor
 from transformers import LlamaModel, LlamaConfig
-from transformers.generation.logits_process import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor
+from transformers.generation.logits_process import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor, MinPLogitsWarper
 
 from .modules.learned_pos_emb import LearnedPositionEmbeddings

@@ -17,17 +19,12 @@ from .modules.t3_config import T3Config
 from .llama_configs import LLAMA_CONFIGS
 from .inference.t3_hf_backend import T3HuggingfaceBackend
 from .inference.alignment_stream_analyzer import AlignmentStreamAnalyzer
+from ..utils import AttrDict
 
 
 logger = logging.getLogger(__name__)
 
 
-class AttrDict(dict):
-    def __init__(self, *args, **kwargs):
-        super(AttrDict, self).__init__(*args, **kwargs)
-        self.__dict__ = self
-
-
 def _ensure_BOT_EOT(text_tokens: Tensor, hp):
     B = text_tokens.size(0)
     assert (text_tokens == hp.start_text_token).int().sum() >= B, "missing start_text_token"

@@ -44,7 +41,9 @@ class T3(nn.Module):
     different PE embedding space for speech.
     """
 
-    def __init__(self, hp=…
+    def __init__(self, hp=None):
+        if hp is None:
+            hp = T3Config.english_only()  # Default to English-only config for backward compatibility
         super().__init__()
         self.hp = hp
         self.cfg = LlamaConfig(**LLAMA_CONFIGS[hp.llama_config_name])

@@ -89,11 +88,13 @@
         t3_cond: T3Cond,
         text_tokens: torch.LongTensor,
         speech_tokens: torch.LongTensor,
+        cfg_weight: float = 0.0,
     ):
         # prepare input embeddings (skip backbone tranformer embeddings)
         cond_emb = self.prepare_conditioning(t3_cond)  # (B, len_cond, dim)
         text_emb = self.text_emb(text_tokens)  # (B, len_text, dim)
+        if cfg_weight > 0.0:
+            text_emb[1].zero_()  # CFG uncond
 
         speech_emb = self.speech_emb(speech_tokens)  # (B, len_speech, dim)
         if self.hp.input_pos_emb == "learned":

@@ -221,10 +222,11 @@
         stop_on_eos=True,
         do_sample=True,
         temperature=0.8,
-        top_p=0.…
+        top_p=0.95,
+        min_p=0.05,
         length_penalty=1.0,
-        repetition_penalty=2…
-        cfg_weight=0,
+        repetition_penalty=1.2,
+        cfg_weight=0.5,
     ):
         """
         Args:

@@ -244,6 +246,7 @@
             t3_cond=t3_cond,
             text_tokens=text_tokens,
             speech_tokens=initial_speech_tokens,
+            cfg_weight=cfg_weight,
         )
 
         # In order to use the standard HF generate method, we need to extend some methods to inject our custom logic

@@ -254,19 +257,24 @@
         # TODO? synchronize the expensive compile function
         # with self.compile_lock:
         if not self.compiled:
-            # …  (old analyzer setup elided in this view)
+            # Default to None for English models, only create for multilingual
+            alignment_stream_analyzer = None
+            if self.hp.is_multilingual:
+                alignment_stream_analyzer = AlignmentStreamAnalyzer(
+                    self.tfmr,
+                    None,
+                    text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
+                    alignment_layer_idx=9,  # TODO: hparam or something?
+                    eos_idx=self.hp.stop_speech_token,
+                )
+                assert alignment_stream_analyzer.eos_idx == self.hp.stop_speech_token
+
             patched_model = T3HuggingfaceBackend(
                 config=self.cfg,
                 llama=self.tfmr,
                 speech_enc=self.speech_emb,
                 speech_head=self.speech_head,
-                …
+                alignment_stream_analyzer=alignment_stream_analyzer,
             )
             self.patched_model = patched_model
             self.compiled = True

@@ -281,7 +289,7 @@
             # max_new_tokens=max_new_tokens or self.hp.max_speech_tokens,
             # num_return_sequences=num_return_sequences,
             # temperature=temperature,
-            # …
+            # min_p=min_p,
             # length_penalty=length_penalty,
             # repetition_penalty=repetition_penalty,
             # do_sample=do_sample,

@@ -306,7 +314,9 @@
 
         # Instantiate the logits processors.
         top_p_warper = TopPLogitsWarper(top_p=top_p)
-        …
+        min_p_warper = MinPLogitsWarper(min_p=min_p)
+        top_p_warper = TopPLogitsWarper(top_p=top_p)
+        repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty))
 
         # ---- Initial Forward Pass (no kv_cache yet) ----
         output = self.patched_model(

@@ -322,21 +332,32 @@
 
         # ---- Generation Loop using kv_cache ----
         for i in tqdm(range(max_new_tokens), desc="Sampling", dynamic_ncols=True):
-            …  (old per-step CFG and logits handling elided in this view)
-            logits = …
+            logits_step = output.logits[:, -1, :]
+            # CFG combine → (1, V)
+            cond = logits_step[0:1, :]
+            uncond = logits_step[1:2, :]
+            cfg = torch.as_tensor(cfg_weight, device=cond.device, dtype=cond.dtype)
+            logits = cond + cfg * (cond - uncond)
+
+            # Apply alignment stream analyzer integrity checks
+            if self.patched_model.alignment_stream_analyzer is not None:
+                if logits.dim() == 1:  # guard in case something upstream squeezed
+                    logits = logits.unsqueeze(0)  # (1, V)
+                # Pass the last generated token for repetition tracking
+                last_token = generated_ids[0, -1].item() if len(generated_ids[0]) > 0 else None
+                logits = self.patched_model.alignment_stream_analyzer.step(logits, next_token=last_token)  # (1, V)
+
+            # Apply repetition penalty
+            ids_for_proc = generated_ids[:1, ...]  # batch = 1
+            logits = repetition_penalty_processor(ids_for_proc, logits)  # expects (B,V)
+
             # Apply temperature scaling.
             if temperature != 1.0:
                 logits = logits / temperature
 
-            # Apply …
-            logits = …
-            logits = top_p_warper(…
+            # Apply min_p and top_p filtering
+            logits = min_p_warper(ids_for_proc, logits)
+            logits = top_p_warper(ids_for_proc, logits)
 
             # Convert logits to probabilities and sample the next token.
             probs = torch.softmax(logits, dim=-1)

@@ -347,6 +368,7 @@
 
             # Check for EOS token.
             if next_token.view(-1) == self.hp.stop_speech_token:
+                logger.info(f"✅ EOS token detected! Stopping generation at step {i+1}")
                 break
 
             # Get embedding for the new token.

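The sampling loop now applies classifier-free guidance manually: the batch carries a conditional row and an unconditional row (whose text embedding was zeroed in `forward` when `cfg_weight > 0`), and the two are combined per step before the logits processors run. The combination rule in isolation (toy shapes, not project code):

import torch

def cfg_combine(logits_step: torch.Tensor, cfg_weight: float) -> torch.Tensor:
    cond = logits_step[0:1, :]    # conditional pass, (1, V)
    uncond = logits_step[1:2, :]  # unconditional pass, (1, V)
    # guided logits: push the distribution away from the unconditional one
    return cond + cfg_weight * (cond - uncond)

logits_step = torch.randn(2, 8194)  # (B=2, V), as produced by the duplicated batch
print(tuple(cfg_combine(logits_step, cfg_weight=0.5).shape))  # (1, 8194)
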