diff --git a/README.md b/README.md index f670e6c7ca6d38b75de49232eedfe87e08f138c4..4fbdaa0c24f6f25a6bce996a3cbe2477c4dc6472 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ sdk_version: 5.29.0 app_file: app.py pinned: false short_description: Expressive Zeroshot TTS +python_version: 3.10 --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference \ No newline at end of file diff --git a/chatterbox/src/chatterbox/__init__.py b/chatterbox/src/chatterbox/__init__.py index c8aa565d6cf00b8eaf2b7896ea751bb8091fc77a..d8c1751b9cb9d33071f414d514a72389bf7b0bdc 100644 --- a/chatterbox/src/chatterbox/__init__.py +++ b/chatterbox/src/chatterbox/__init__.py @@ -1,2 +1,17 @@ +try: + from importlib.metadata import version, PackageNotFoundError + try: + __version__ = version("chatterbox-tts") + except PackageNotFoundError: + __version__ = "0.1.4" # Default fallback version +except ImportError: + from importlib_metadata import version, PackageNotFoundError # For Python <3.8 + try: + __version__ = version("chatterbox-tts") + except PackageNotFoundError: + __version__ = "0.1.4" + + from .tts import ChatterboxTTS from .vc import ChatterboxVC +from .mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES \ No newline at end of file diff --git a/chatterbox/src/chatterbox/__pycache__/__init__.cpython-311.pyc b/chatterbox/src/chatterbox/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index f83ebeb4d81a7ddafc45dc5a78c13fca1bd8e998..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/__pycache__/tts.cpython-311.pyc b/chatterbox/src/chatterbox/__pycache__/tts.cpython-311.pyc deleted file mode 100644 index 41e40d2b049af1fbc940bfd0f710e5c2365e9897..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/__pycache__/tts.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/__pycache__/utils.cpython-311.pyc b/chatterbox/src/chatterbox/__pycache__/utils.cpython-311.pyc deleted file mode 100644 index 07d1a9336614c485bfd9e5e1b992020cbf64e156..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/__pycache__/utils.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/__pycache__/vc.cpython-311.pyc b/chatterbox/src/chatterbox/__pycache__/vc.cpython-311.pyc deleted file mode 100644 index 3efca27acc1b5a881a57348304d70ccaf540ccd5..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/__pycache__/vc.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/__init__.py b/chatterbox/src/chatterbox/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/chatterbox/src/chatterbox/models/s3gen/__pycache__/__init__.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 0440f658e117cddaa5002926037fbc3ff8aac5cb..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/__pycache__/const.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/__pycache__/const.cpython-311.pyc deleted file mode 100644 index 23d28ef23a8e4dee953d76ab3d81c0a77ff65565..0000000000000000000000000000000000000000 Binary files 
a/chatterbox/src/chatterbox/models/s3gen/__pycache__/const.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/__pycache__/decoder.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/__pycache__/decoder.cpython-311.pyc deleted file mode 100644 index 09efdf9864f984a380da312b6f9a370d24f899db..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/__pycache__/decoder.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc deleted file mode 100644 index a525d29d9797bc22e9116c9c0864a9e035cd92f3..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/__pycache__/flow.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/__pycache__/flow.cpython-311.pyc deleted file mode 100644 index 0bde78561834629249db2bd579c6296f231795cb..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/__pycache__/flow.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/__pycache__/flow_matching.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/__pycache__/flow_matching.cpython-311.pyc deleted file mode 100644 index 6ec2c56309afcda69e64669e107cbc6b2cecf7a2..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/__pycache__/flow_matching.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/__pycache__/hifigan.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/__pycache__/hifigan.cpython-311.pyc deleted file mode 100644 index 373e419165ad8449b4a81d7097e2afa23acd17a0..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/__pycache__/hifigan.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/__pycache__/s3gen.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/__pycache__/s3gen.cpython-311.pyc deleted file mode 100644 index 36822202135e6b9d334d7691413d9de9d87d9b1c..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/__pycache__/s3gen.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/__pycache__/xvector.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/__pycache__/xvector.cpython-311.pyc deleted file mode 100644 index e989a9dd0f760e44c040e54daf9c448e405b6421..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/__pycache__/xvector.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/configs.py b/chatterbox/src/chatterbox/models/s3gen/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..b09b2e52c2873095c81a0d1d7cb97130cefdc7f5 --- /dev/null +++ b/chatterbox/src/chatterbox/models/s3gen/configs.py @@ -0,0 +1,10 @@ +from ..utils import AttrDict + +CFM_PARAMS = AttrDict({ + "sigma_min": 1e-06, + "solver": "euler", + "t_scheduler": "cosine", + "training_cfg_rate": 0.2, + "inference_cfg_rate": 0.7, + "reg_loss_type": "l1" +}) diff --git a/chatterbox/src/chatterbox/models/s3gen/flow.py b/chatterbox/src/chatterbox/models/s3gen/flow.py index 
a460ddef5db032967e849a2c4e134fcdf58d622d..ad19cfa3ef00fc60c2c7371a152dc8a368cfe7d0 100644 --- a/chatterbox/src/chatterbox/models/s3gen/flow.py +++ b/chatterbox/src/chatterbox/models/s3gen/flow.py @@ -14,32 +14,54 @@ import logging import random from typing import Dict, Optional + +logger = logging.getLogger(__name__) import torch import torch.nn as nn from torch.nn import functional as F -from omegaconf import DictConfig from .utils.mask import make_pad_mask +from .configs import CFM_PARAMS class MaskedDiffWithXvec(torch.nn.Module): - def __init__(self, - input_size: int = 512, - output_size: int = 80, - spk_embed_dim: int = 192, - output_type: str = "mel", - vocab_size: int = 4096, - input_frame_rate: int = 50, - only_mask_loss: bool = True, - encoder: torch.nn.Module = None, - length_regulator: torch.nn.Module = None, - decoder: torch.nn.Module = None, - decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, - 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', - 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), - 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, - 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}, - mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, - 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}): + def __init__( + self, + input_size: int = 512, + output_size: int = 80, + spk_embed_dim: int = 192, + output_type: str = "mel", + vocab_size: int = 4096, + input_frame_rate: int = 50, + only_mask_loss: bool = True, + encoder: torch.nn.Module = None, + length_regulator: torch.nn.Module = None, + decoder: torch.nn.Module = None, + decoder_conf: Dict = { + 'in_channels': 240, + 'out_channel': 80, + 'spk_emb_dim': 80, + 'n_spks': 1, + 'cfm_params': CFM_PARAMS, + 'decoder_params': { + 'channels': [256, 256], + 'dropout': 0.0, + 'attention_head_dim': 64, + 'n_blocks': 4, + 'num_mid_blocks': 12, + 'num_heads': 8, + 'act_fn': 'gelu', + } + }, + mel_feat_conf: Dict = { + 'n_fft': 1024, + 'num_mels': 80, + 'sampling_rate': 22050, + 'hop_size': 256, + 'win_size': 1024, + 'fmin': 0, + 'fmax': 8000 + } + ): super().__init__() self.input_size = input_size self.output_size = output_size @@ -74,7 +96,7 @@ class MaskedDiffWithXvec(torch.nn.Module): # concat text and prompt_text mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device) - token = self.input_embedding(torch.clamp(token, min=0)) * mask + token = self.input_embedding(torch.clamp(token, min=0, max=self.input_embedding.num_embeddings-1)) * mask # text encode h, h_lengths = self.encoder(token, token_len) @@ -124,7 +146,13 @@ class MaskedDiffWithXvec(torch.nn.Module): token_len1, token_len2 = prompt_token.shape[1], token.shape[1] token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding) - token = self.input_embedding(torch.clamp(token, min=0)) * mask + + # Check for out-of-bounds token IDs + vocab_size = self.input_embedding.num_embeddings + if token.max() >= vocab_size or token.min() < 0: + logger.warning(f"S3Gen: Token IDs out of bounds: min={token.min().item()}, max={token.max().item()}, vocab_size={vocab_size}") + + token = self.input_embedding(torch.clamp(token, min=0, max=vocab_size-1)) * mask # text encode h, h_lengths = self.encoder(token, token_len) @@ -153,25 +181,45 @@ class
CausalMaskedDiffWithXvec(torch.nn.Module): - def __init__(self, - input_size: int = 512, - output_size: int = 80, - spk_embed_dim: int = 192, - output_type: str = "mel", - vocab_size: int = 6561, - input_frame_rate: int = 25, - only_mask_loss: bool = True, - token_mel_ratio: int = 2, - pre_lookahead_len: int = 3, - encoder: torch.nn.Module = None, - decoder: torch.nn.Module = None, - decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, - 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', - 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), - 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, - 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}, - mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, - 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}): + def __init__( + self, + input_size: int = 512, + output_size: int = 80, + spk_embed_dim: int = 192, + output_type: str = "mel", + vocab_size: int = 6561, + input_frame_rate: int = 25, + only_mask_loss: bool = True, + token_mel_ratio: int = 2, + pre_lookahead_len: int = 3, + encoder: torch.nn.Module = None, + decoder: torch.nn.Module = None, + decoder_conf: Dict = { + 'in_channels': 240, + 'out_channel': 80, + 'spk_emb_dim': 80, + 'n_spks': 1, + 'cfm_params': CFM_PARAMS, + 'decoder_params': { + 'channels': [256, 256], + 'dropout': 0.0, + 'attention_head_dim': 64, + 'n_blocks': 4, + 'num_mid_blocks': 12, + 'num_heads': 8, + 'act_fn': 'gelu', + } + }, + mel_feat_conf: Dict = { + 'n_fft': 1024, + 'num_mels': 80, + 'sampling_rate': 22050, + 'hop_size': 256, + 'win_size': 1024, + 'fmin': 0, + 'fmax': 8000 + } + ): super().__init__() self.input_size = input_size self.output_size = output_size @@ -215,7 +263,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module): # concat text and prompt_text token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding) - token = self.input_embedding(torch.clamp(token, min=0)) * mask + token = self.input_embedding(torch.clamp(token, min=0, max=self.input_embedding.num_embeddings-1)) * mask # text encode h, h_lengths = self.encoder(token, token_len) diff --git a/chatterbox/src/chatterbox/models/s3gen/flow_matching.py b/chatterbox/src/chatterbox/models/s3gen/flow_matching.py index 8307e3c0d6120a81b6ff414fafa30e9fc63d015c..ecd69fa485d93fcbcc5edf8e400d053b0a6e9658 100644 --- a/chatterbox/src/chatterbox/models/s3gen/flow_matching.py +++ b/chatterbox/src/chatterbox/models/s3gen/flow_matching.py @@ -15,17 +15,7 @@ import threading import torch import torch.nn.functional as F from .matcha.flow_matching import BASECFM -from omegaconf import OmegaConf - - -CFM_PARAMS = OmegaConf.create({ - "sigma_min": 1e-06, - "solver": "euler", - "t_scheduler": "cosine", - "training_cfg_rate": 0.2, - "inference_cfg_rate": 0.7, - "reg_loss_type": "l1" -}) +from .configs import CFM_PARAMS class ConditionalCFM(BASECFM): diff --git a/chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc deleted file mode 100644 index 38b6607329e7bad030998926218327187f1f05ac..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc and /dev/null differ diff --git 
a/chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc deleted file mode 100644 index 1880dce7dc46c4d150c17cfb3a638d0d5349313b..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc deleted file mode 100644 index bd93032987a4f6492fd0e145d61b54b7e172afe4..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/s3gen.py b/chatterbox/src/chatterbox/models/s3gen/s3gen.py index 97b7c0bd40ad6cd258ca3c4bd4ae752c78f28b19..b1cf05e62ace5cf5197e65bb85271b04ec9afbee 100644 --- a/chatterbox/src/chatterbox/models/s3gen/s3gen.py +++ b/chatterbox/src/chatterbox/models/s3gen/s3gen.py @@ -19,7 +19,6 @@ import torch import torchaudio as ta from functools import lru_cache from typing import Optional -from omegaconf import DictConfig from ..s3tokenizer import S3_SR, SPEECH_VOCAB_SIZE, S3Tokenizer from .const import S3GEN_SR @@ -31,6 +30,7 @@ from .hifigan import HiFTGenerator from .transformer.upsample_encoder import UpsampleConformerEncoder from .flow_matching import CausalConditionalCFM from .decoder import ConditionalDecoder +from .configs import CFM_PARAMS def drop_invalid_tokens(x): @@ -85,14 +85,7 @@ class S3Token2Mel(torch.nn.Module): num_heads=8, act_fn='gelu', ) - cfm_params = DictConfig({ - "sigma_min": 1e-06, - "solver": 'euler', - "t_scheduler": 'cosine', - "training_cfg_rate": 0.2, - "inference_cfg_rate": 0.7, - "reg_loss_type": 'l1', - }) + cfm_params = CFM_PARAMS decoder = CausalConditionalCFM( spk_emb_dim=80, cfm_params=cfm_params, diff --git a/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 2112d2f573014db0ea5ad9e716243b517ae029a4..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc deleted file mode 100644 index ea5df6579aa5e27a90b13282143dc0fe9173fe25..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc deleted file mode 100644 index 3920627114182de7dfd509107f2f02fe2d38bad0..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc deleted file mode 100644 
index 27b290d0a771814370b3f5e0689eab144960bc8e..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc deleted file mode 100644 index 15c05f8a86df64b9027c566dfc2a9554030058ae..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc deleted file mode 100644 index 80713f3cd8b4a26360ea75611e67e2bdb1b9153d..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc deleted file mode 100644 index 5dde70c679cfe5f723ee1bbba148f86ca474d346..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc deleted file mode 100644 index 62655a1834a7c42c135a0801d3e651389485a85a..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc deleted file mode 100644 index dbc32d3e18b8287082fcc53ef1afff8273084c54..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc deleted file mode 100644 index 51c2e3c410214dee72ff90ff1088aed9db91d0a2..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/mask.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/mask.cpython-311.pyc deleted file mode 100644 index 5fea7167260f26771101b982fa2757b66d978a06..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/mask.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/mel.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/mel.cpython-311.pyc deleted file mode 100644 index 
c960eefbe30b8ff718a09496220ea970411e5fe5..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/mel.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3gen/utils/mel.py b/chatterbox/src/chatterbox/models/s3gen/utils/mel.py index 5a9ff9d11d67e1d6a96dd97d45a02366a3bba300..907d2b5770d2690b3e53a05e1952e3848b96ee41 100644 --- a/chatterbox/src/chatterbox/models/s3gen/utils/mel.py +++ b/chatterbox/src/chatterbox/models/s3gen/utils/mel.py @@ -1,8 +1,11 @@ """mel-spectrogram extraction in Matcha-TTS""" +import logging from librosa.filters import mel as librosa_mel_fn import torch import numpy as np +logger = logging.getLogger(__name__) + # NOTE: they declared these global vars mel_basis = {} @@ -42,10 +45,11 @@ def mel_spectrogram(y, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=48 if len(y.shape) == 1: y = y[None, ] - if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) - if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) + # Debug: Check for audio clipping (values outside [-1.0, 1.0] range) + min_val = torch.min(y) + max_val = torch.max(y) + if min_val < -1.0 or max_val > 1.0: + logger.warning(f"Audio values outside normalized range: min={min_val.item():.4f}, max={max_val.item():.4f}") global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned if f"{str(fmax)}_{str(y.device)}" not in mel_basis: diff --git a/chatterbox/src/chatterbox/models/s3tokenizer/__pycache__/__init__.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3tokenizer/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 676514e1ee4e272d497749d9dc3c2bbaf51f3d9b..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3tokenizer/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/s3tokenizer/__pycache__/s3tokenizer.cpython-311.pyc b/chatterbox/src/chatterbox/models/s3tokenizer/__pycache__/s3tokenizer.cpython-311.pyc deleted file mode 100644 index c374ddbcc7529c57930ed5f4f562049f6924dad8..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/s3tokenizer/__pycache__/s3tokenizer.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/t3/__pycache__/__init__.cpython-311.pyc b/chatterbox/src/chatterbox/models/t3/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 4cc167688883d9a0560e21a89dd39e4bb0657ed5..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/t3/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/t3/__pycache__/llama_configs.cpython-311.pyc b/chatterbox/src/chatterbox/models/t3/__pycache__/llama_configs.cpython-311.pyc deleted file mode 100644 index a27f7451f472dd6b147438f40ff196e1d481387e..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/t3/__pycache__/llama_configs.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/t3/__pycache__/t3.cpython-311.pyc b/chatterbox/src/chatterbox/models/t3/__pycache__/t3.cpython-311.pyc deleted file mode 100644 index 61842097945f681d59e0d36696163f38e5ef5ba0..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/t3/__pycache__/t3.cpython-311.pyc and /dev/null differ diff --git
a/chatterbox/src/chatterbox/models/t3/inference/__pycache__/alignment_stream_analyzer.cpython-311.pyc b/chatterbox/src/chatterbox/models/t3/inference/__pycache__/alignment_stream_analyzer.cpython-311.pyc deleted file mode 100644 index 268c28edf2581499f3f563f15f4267b5167d9521..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/t3/inference/__pycache__/alignment_stream_analyzer.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc b/chatterbox/src/chatterbox/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc deleted file mode 100644 index dd403a13db05ab87f6f48b840e42468dde36a692..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py b/chatterbox/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py index d3a144f0f7f0cdef4a7a4c049db3b5433744296e..255d50f03675b5405461c2793b111c3d474b43d6 100644 --- a/chatterbox/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py +++ b/chatterbox/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py @@ -10,6 +10,9 @@ from types import MethodType logger = logging.getLogger(__name__) +LLAMA_ALIGNED_HEADS = [(12, 15), (13, 11), (9, 2)] + + @dataclass class AlignmentAnalysisResult: # was this frame detected as being part of a noisy beginning chunk with potential hallucinations? @@ -49,21 +52,22 @@ class AlignmentStreamAnalyzer: self.complete = False self.completed_at = None + + # Track generated tokens for repetition detection + self.generated_tokens = [] # Using `output_attentions=True` is incompatible with optimized attention kernels, so # using it for all layers slows things down too much. We can apply it to just one layer # by intercepting the kwargs and adding a forward hook (credit: jrm) - self.last_aligned_attn = None - self._add_attention_spy(tfmr, alignment_layer_idx) + self.last_aligned_attns = [] + for i, (layer_idx, head_idx) in enumerate(LLAMA_ALIGNED_HEADS): + self.last_aligned_attns += [None] + self._add_attention_spy(tfmr, i, layer_idx, head_idx) - def _add_attention_spy(self, tfmr, alignment_layer_idx): + def _add_attention_spy(self, tfmr, buffer_idx, layer_idx, head_idx): """ Adds a forward hook to a specific attention layer to collect outputs. - Using `output_attentions=True` is incompatible with optimized attention kernels, so - using it for all layers slows things down too much. - (credit: jrm) """ - def attention_forward_hook(module, input, output): """ See `LlamaAttention.forward`; the output is a 3-tuple: `attn_output, attn_weights, past_key_value`. @@ -71,27 +75,23 @@ class AlignmentStreamAnalyzer: - When `output_attentions=True`, `LlamaSdpaAttention.forward` calls `LlamaAttention.forward`. - `attn_output` has shape [B, H, T0, T0] for the 0th entry, and [B, H, 1, T0+i] for the rest i-th. """ - step_attention = output[1].cpu() # (B, 16, N, N) - self.last_aligned_attn = step_attention[0].mean(0) # (N, N) - - target_layer = tfmr.layers[alignment_layer_idx].self_attn - hook_handle = target_layer.register_forward_hook(attention_forward_hook) - - # Backup original forward - original_forward = target_layer.forward - def patched_forward(self, *args, **kwargs): - kwargs['output_attentions'] = True - return original_forward(*args, **kwargs) - - # TODO: how to unpatch it? 
- target_layer.forward = MethodType(patched_forward, target_layer) - - def step(self, logits): + if isinstance(output, tuple) and len(output) > 1 and output[1] is not None: + step_attention = output[1].cpu() # (B, n_heads, T0, Ti) + self.last_aligned_attns[buffer_idx] = step_attention[0, head_idx] # (T0, Ti) + + target_layer = tfmr.layers[layer_idx].self_attn + # Register hook and store the handle + target_layer.register_forward_hook(attention_forward_hook) + if hasattr(tfmr, 'config') and hasattr(tfmr.config, 'output_attentions'): + self.original_output_attentions = tfmr.config.output_attentions + tfmr.config.output_attentions = True + + def step(self, logits, next_token=None): """ Emits an AlignmentAnalysisResult into the output queue, and potentially modifies the logits to force an EOS. """ # extract approximate alignment matrix chunk (1 frame at a time after the first chunk) - aligned_attn = self.last_aligned_attn # (N, N) + aligned_attn = torch.stack(self.last_aligned_attns).mean(dim=0) # (N, N) i, j = self.text_tokens_slice if self.curr_frame_pos == 0: # first chunk has conditioning info, text tokens, and BOS token @@ -133,22 +133,46 @@ class AlignmentStreamAnalyzer: last_text_token_duration = A[15:, -3:].sum() # Activations for the final token that last too long are likely hallucinations. - long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 10) # 400ms + long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 5) # 200ms # If there are activations in previous tokens after generation has completed, assume this is a repetition error. - repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5) + alignment_repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5) + + # Track generated tokens for repetition detection + if next_token is not None: + # Convert tensor to scalar if needed + if isinstance(next_token, torch.Tensor): + token_id = next_token.item() if next_token.numel() == 1 else next_token.view(-1)[0].item() + else: + token_id = next_token + self.generated_tokens.append(token_id) + + # Keep only last 8 tokens to prevent memory issues + if len(self.generated_tokens) > 8: + self.generated_tokens = self.generated_tokens[-8:] + + # Check for excessive token repetition (same token emitted twice in a row) + token_repetition = ( + # self.complete and + len(self.generated_tokens) >= 3 and + len(set(self.generated_tokens[-2:])) == 1 + ) + + if token_repetition: + repeated_token = self.generated_tokens[-1] + logger.warning(f"🚨 Detected 2x repetition of token {repeated_token}") + + # Suppress EoS to prevent early termination + if cur_text_posn < S - 3 and S > 5: # Only suppress if text is longer than 5 tokens + logits[..., self.eos_idx] = -2**15 # If a bad ending is detected, force emit EOS by modifying logits # NOTE: this means logits may be inconsistent with latents!
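# (Illustrative aside, not part of the upstream change: the forced EOS below
# relies on softmax shift-invariance. Flooding every logit with -2**15 and
# raising only the EOS logit to +2**15 concentrates essentially all
# probability mass on EOS, so sampling terminates on the next step
# regardless of temperature or the downstream warpers.)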
- if long_tail or repetition: - logger.warn(f"forcing EOS token, {long_tail=}, {repetition=}") + if long_tail or alignment_repetition or token_repetition: + logger.warning(f"forcing EOS token, {long_tail=}, {alignment_repetition=}, {token_repetition=}") # (±2**15 is safe for all dtypes >= 16bit) logits = -(2**15) * torch.ones_like(logits) logits[..., self.eos_idx] = 2**15 - # Suppress EoS to prevent early termination - if cur_text_posn < S - 3: # FIXME: arbitrary - logits[..., self.eos_idx] = -2**15 - self.curr_frame_pos += 1 return logits diff --git a/chatterbox/src/chatterbox/models/t3/modules/__pycache__/cond_enc.cpython-311.pyc b/chatterbox/src/chatterbox/models/t3/modules/__pycache__/cond_enc.cpython-311.pyc deleted file mode 100644 index 245d17a8449324aa924d2da5217efcb4ca7b2d80..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/t3/modules/__pycache__/cond_enc.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/t3/modules/__pycache__/learned_pos_emb.cpython-311.pyc b/chatterbox/src/chatterbox/models/t3/modules/__pycache__/learned_pos_emb.cpython-311.pyc deleted file mode 100644 index 1eb956b90f6522dbc0fe9f1b98d5b22c099095c2..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/t3/modules/__pycache__/learned_pos_emb.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/t3/modules/__pycache__/perceiver.cpython-311.pyc b/chatterbox/src/chatterbox/models/t3/modules/__pycache__/perceiver.cpython-311.pyc deleted file mode 100644 index a860806c99789bfee9578612c16311660cbb4bb6..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/t3/modules/__pycache__/perceiver.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/t3/modules/__pycache__/t3_config.cpython-311.pyc b/chatterbox/src/chatterbox/models/t3/modules/__pycache__/t3_config.cpython-311.pyc deleted file mode 100644 index 4b4b4302eb961d30881318a4db9e8cd879b11ba6..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/t3/modules/__pycache__/t3_config.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/t3/modules/t3_config.py b/chatterbox/src/chatterbox/models/t3/modules/t3_config.py index 2769d835692578c7f8fb0f9bcf6b42daa4b0cd03..55b43903d794465d86445520227e96c34eb0112d 100644 --- a/chatterbox/src/chatterbox/models/t3/modules/t3_config.py +++ b/chatterbox/src/chatterbox/models/t3/modules/t3_config.py @@ -2,26 +2,40 @@ from ..llama_configs import LLAMA_CONFIGS class T3Config: - start_text_token = 255 - stop_text_token = 0 - text_tokens_dict_size = 704 - max_text_tokens = 2048 + def __init__(self, text_tokens_dict_size=704): + self.start_text_token = 255 + self.stop_text_token = 0 + self.text_tokens_dict_size = text_tokens_dict_size + self.max_text_tokens = 2048 - start_speech_token = 6561 - stop_speech_token = 6562 - speech_tokens_dict_size = 8194 - max_speech_tokens = 4096 + self.start_speech_token = 6561 + self.stop_speech_token = 6562 + self.speech_tokens_dict_size = 8194 + self.max_speech_tokens = 4096 - llama_config_name = "Llama_520M" - input_pos_emb = "learned" - speech_cond_prompt_len = 150 + self.llama_config_name = "Llama_520M" + self.input_pos_emb = "learned" + self.speech_cond_prompt_len = 150 - # For T3CondEnc - encoder_type = "voice_encoder" - speaker_embed_size = 256 - use_perceiver_resampler = True - emotion_adv = True + self.encoder_type = "voice_encoder" + 
self.speaker_embed_size = 256 + self.use_perceiver_resampler = True + self.emotion_adv = True @property def n_channels(self): return LLAMA_CONFIGS[self.llama_config_name]["hidden_size"] + + @property + def is_multilingual(self): + return self.text_tokens_dict_size == 2454 + + @classmethod + def english_only(cls): + """Create configuration for English-only TTS model.""" + return cls(text_tokens_dict_size=704) + + @classmethod + def multilingual(cls): + """Create configuration for multilingual TTS model.""" + return cls(text_tokens_dict_size=2454) \ No newline at end of file diff --git a/chatterbox/src/chatterbox/models/t3/t3.py b/chatterbox/src/chatterbox/models/t3/t3.py index 0165f9f9afd5248ec790cec4311dfe96ea4e0f95..905566c62789d1f836caa48f7ce9385744b9c1db 100644 --- a/chatterbox/src/chatterbox/models/t3/t3.py +++ b/chatterbox/src/chatterbox/models/t3/t3.py @@ -3,12 +3,14 @@ import logging from typing import Union, Optional, List +logger = logging.getLogger(__name__) + from tqdm import tqdm import torch import torch.nn.functional as F from torch import nn, Tensor from transformers import LlamaModel, LlamaConfig -from transformers.generation.logits_process import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor +from transformers.generation.logits_process import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor, MinPLogitsWarper from .modules.learned_pos_emb import LearnedPositionEmbeddings @@ -17,17 +19,12 @@ from .modules.t3_config import T3Config from .llama_configs import LLAMA_CONFIGS from .inference.t3_hf_backend import T3HuggingfaceBackend from .inference.alignment_stream_analyzer import AlignmentStreamAnalyzer +from ..utils import AttrDict logger = logging.getLogger(__name__) -class AttrDict(dict): - def __init__(self, *args, **kwargs): - super(AttrDict, self).__init__(*args, **kwargs) - self.__dict__ = self - - def _ensure_BOT_EOT(text_tokens: Tensor, hp): B = text_tokens.size(0) assert (text_tokens == hp.start_text_token).int().sum() >= B, "missing start_text_token" @@ -44,7 +41,9 @@ class T3(nn.Module): different PE embedding space for speech. """ - def __init__(self, hp=T3Config()): + def __init__(self, hp=None): + if hp is None: + hp = T3Config.english_only() # Default to English-only config for backward compatibility super().__init__() self.hp = hp self.cfg = LlamaConfig(**LLAMA_CONFIGS[hp.llama_config_name]) @@ -89,11 +88,13 @@ class T3(nn.Module): t3_cond: T3Cond, text_tokens: torch.LongTensor, speech_tokens: torch.LongTensor, + cfg_weight: float = 0.0, ): # prepare input embeddings (skip backbone transformer embeddings) cond_emb = self.prepare_conditioning(t3_cond) # (B, len_cond, dim) text_emb = self.text_emb(text_tokens) # (B, len_text, dim) - text_emb[1].zero_() # CFG uncond + if cfg_weight > 0.0: + text_emb[1].zero_() # CFG uncond speech_emb = self.speech_emb(speech_tokens) # (B, len_speech, dim) if self.hp.input_pos_emb == "learned": @@ -221,10 +222,11 @@ class T3(nn.Module): stop_on_eos=True, do_sample=True, temperature=0.8, - top_p=0.8, + top_p=0.95, + min_p=0.05, length_penalty=1.0, - repetition_penalty=2.0, - cfg_weight=0, + repetition_penalty=1.2, + cfg_weight=0.5, ): """ Args: @@ -244,6 +246,7 @@ class T3(nn.Module): t3_cond=t3_cond, text_tokens=text_tokens, speech_tokens=initial_speech_tokens, + cfg_weight=cfg_weight, ) # In order to use the standard HF generate method, we need to extend some methods to inject our custom logic # TODO?
synchronize the expensive compile function # with self.compile_lock: if not self.compiled: - # alignment_stream_analyzer = AlignmentStreamAnalyzer( - # self.tfmr, - # None, - # text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)), - # alignment_layer_idx=9, # TODO: hparam or something? - # eos_idx=self.hp.stop_speech_token, - # ) + # Default to None for English models, only create for multilingual + alignment_stream_analyzer = None + if self.hp.is_multilingual: + alignment_stream_analyzer = AlignmentStreamAnalyzer( + self.tfmr, + None, + text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)), + alignment_layer_idx=9, # TODO: hparam or something? + eos_idx=self.hp.stop_speech_token, + ) + assert alignment_stream_analyzer.eos_idx == self.hp.stop_speech_token + patched_model = T3HuggingfaceBackend( config=self.cfg, llama=self.tfmr, speech_enc=self.speech_emb, speech_head=self.speech_head, - # alignment_stream_analyzer=alignment_stream_analyzer, + alignment_stream_analyzer=alignment_stream_analyzer, ) self.patched_model = patched_model self.compiled = True @@ -281,7 +289,7 @@ class T3(nn.Module): # max_new_tokens=max_new_tokens or self.hp.max_speech_tokens, # num_return_sequences=num_return_sequences, # temperature=temperature, - # top_p=top_p, + # min_p=min_p, # length_penalty=length_penalty, # repetition_penalty=repetition_penalty, # do_sample=do_sample, @@ -306,7 +314,8 @@ # Instantiate the logits processors. top_p_warper = TopPLogitsWarper(top_p=top_p) - repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty) + min_p_warper = MinPLogitsWarper(min_p=min_p) + repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty)) # ---- Initial Forward Pass (no kv_cache yet) ---- output = self.patched_model( @@ -322,21 +332,32 @@ # ---- Generation Loop using kv_cache ---- for i in tqdm(range(max_new_tokens), desc="Sampling", dynamic_ncols=True): - logits = output.logits[:, -1, :] - - # CFG - logits_cond = logits[0:1] - logits_uncond = logits[1:2] - logits = logits_cond + cfg_weight * (logits_cond - logits_uncond) - logits = logits.squeeze(1) - + logits_step = output.logits[:, -1, :] + # CFG combine → (1, V) + cond = logits_step[0:1, :] + uncond = logits_step[1:2, :] + cfg = torch.as_tensor(cfg_weight, device=cond.device, dtype=cond.dtype) + logits = cond + cfg * (cond - uncond) + + # Apply alignment stream analyzer integrity checks + if self.patched_model.alignment_stream_analyzer is not None: + if logits.dim() == 1: # guard in case something upstream squeezed + logits = logits.unsqueeze(0) # (1, V) + # Pass the last generated token for repetition tracking + last_token = generated_ids[0, -1].item() if len(generated_ids[0]) > 0 else None + logits = self.patched_model.alignment_stream_analyzer.step(logits, next_token=last_token) # (1, V) + + # Apply repetition penalty + ids_for_proc = generated_ids[:1, ...] # batch = 1 + logits = repetition_penalty_processor(ids_for_proc, logits) # expects (B,V) + # Apply temperature scaling. if temperature != 1.0: logits = logits / temperature - - # Apply repetition penalty and top-p filtering. - logits = repetition_penalty_processor(generated_ids, logits) - logits = top_p_warper(None, logits) + + # Apply min_p and top_p filtering + logits = min_p_warper(ids_for_proc, logits) + logits = top_p_warper(ids_for_proc, logits) # Convert logits to probabilities and sample the next token.
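# (Illustrative aside, not upstream code: min_p keeps only tokens whose
# probability is at least min_p times that of the most likely token, and
# top_p then keeps the smallest nucleus of tokens whose cumulative
# probability reaches top_p; both warpers set rejected logits to -inf, so
# the softmax below renormalizes over the surviving candidates before
# multinomial sampling.)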
probs = torch.softmax(logits, dim=-1) @@ -347,6 +368,7 @@ class T3(nn.Module): # Check for EOS token. if next_token.view(-1) == self.hp.stop_speech_token: + logger.info(f"✅ EOS token detected! Stopping generation at step {i+1}") break # Get embedding for the new token. diff --git a/chatterbox/src/chatterbox/models/tokenizers/__init__.py b/chatterbox/src/chatterbox/models/tokenizers/__init__.py index 97457e6fd720a10b2c64d2cdbabce9ca5fbf9aad..fdf6d727a14bc20a0ce3a5dd41cf1ce44b6b330a 100644 --- a/chatterbox/src/chatterbox/models/tokenizers/__init__.py +++ b/chatterbox/src/chatterbox/models/tokenizers/__init__.py @@ -1 +1 @@ -from .tokenizer import EnTokenizer +from .tokenizer import EnTokenizer, MTLTokenizer \ No newline at end of file diff --git a/chatterbox/src/chatterbox/models/tokenizers/__pycache__/__init__.cpython-311.pyc b/chatterbox/src/chatterbox/models/tokenizers/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index d24dcfb03b0a46ff8aee7ac1fafbe9ea34b6aa0c..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/tokenizers/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/tokenizers/__pycache__/tokenizer.cpython-311.pyc b/chatterbox/src/chatterbox/models/tokenizers/__pycache__/tokenizer.cpython-311.pyc deleted file mode 100644 index c788f3b1e553dc2e385ee5397d888e262a23334f..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/tokenizers/__pycache__/tokenizer.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/tokenizers/tokenizer.py b/chatterbox/src/chatterbox/models/tokenizers/tokenizer.py index f3536bc24db7d37cca9faff11c064c2c5d7c1c64..84d45d35d2db9c6c576a4af98a7ab91a704af9f2 100644 --- a/chatterbox/src/chatterbox/models/tokenizers/tokenizer.py +++ b/chatterbox/src/chatterbox/models/tokenizers/tokenizer.py @@ -1,7 +1,11 @@ import logging +import json import torch +from pathlib import Path +from unicodedata import category, normalize from tokenizers import Tokenizer +from huggingface_hub import hf_hub_download # Special tokens @@ -28,7 +32,7 @@ class EnTokenizer: text_tokens = torch.IntTensor(text_tokens).unsqueeze(0) return text_tokens - def encode( self, txt: str, verbose=False): + def encode(self, txt: str): """ clean_text > (append `lang_id`) > replace SPACE > encode text using Tokenizer """ @@ -41,10 +45,269 @@ class EnTokenizer: if isinstance(seq, torch.Tensor): seq = seq.cpu().numpy() - txt: str = self.tokenizer.decode(seq, - skip_special_tokens=False) + txt: str = self.tokenizer.decode(seq, skip_special_tokens=False) txt = txt.replace(' ', '') txt = txt.replace(SPACE, ' ') txt = txt.replace(EOT, '') txt = txt.replace(UNK, '') return txt + + +# Model repository +REPO_ID = "ResembleAI/chatterbox" + +# Global instances for optional dependencies +_kakasi = None +_dicta = None +_russian_stresser = None + + +def is_kanji(c: str) -> bool: + """Check if character is kanji.""" + return 19968 <= ord(c) <= 40959 + + +def is_katakana(c: str) -> bool: + """Check if character is katakana.""" + return 12449 <= ord(c) <= 12538 + + +def hiragana_normalize(text: str) -> str: + """Japanese text normalization: converts kanji to hiragana; katakana remains the same.""" + global _kakasi + + try: + if _kakasi is None: + import pykakasi + _kakasi = pykakasi.kakasi() + + result = _kakasi.convert(text) + out = [] + + for r in result: + inp = r['orig'] + hira = r["hira"] + + # Any kanji in the phrase + if any([is_kanji(c) for c in
inp]): + if hira and hira[0] in ["は", "へ"]: # Safety check for empty hira + hira = " " + hira + out.append(hira) + + # All katakana + elif all([is_katakana(c) for c in inp]) if inp else False: # Safety check for empty inp + out.append(r['orig']) + + else: + out.append(inp) + + normalized_text = "".join(out) + + # Decompose Japanese characters for tokenizer compatibility + import unicodedata + normalized_text = unicodedata.normalize('NFKD', normalized_text) + + return normalized_text + + except ImportError: + logger.warning("pykakasi not available - Japanese text processing skipped") + return text + + +def add_hebrew_diacritics(text: str) -> str: + """Hebrew text normalization: adds diacritics to Hebrew text.""" + global _dicta + + try: + if _dicta is None: + from dicta_onnx import Dicta + _dicta = Dicta() + + return _dicta.add_diacritics(text) + + except ImportError: + logger.warning("dicta_onnx not available - Hebrew text processing skipped") + return text + except Exception as e: + logger.warning(f"Hebrew diacritization failed: {e}") + return text + + +def korean_normalize(text: str) -> str: + """Korean text normalization: decompose syllables into Jamo for tokenization.""" + + def decompose_hangul(char): + """Decompose Korean syllable into Jamo components.""" + if not ('\uac00' <= char <= '\ud7af'): + return char + + # Hangul decomposition formula + base = ord(char) - 0xAC00 + initial = chr(0x1100 + base // (21 * 28)) + medial = chr(0x1161 + (base % (21 * 28)) // 28) + final = chr(0x11A7 + base % 28) if base % 28 > 0 else '' + + return initial + medial + final + + # Decompose syllables and normalize punctuation + result = ''.join(decompose_hangul(char) for char in text) + return result.strip() + + +class ChineseCangjieConverter: + """Converts Chinese characters to Cangjie codes for tokenization.""" + + def __init__(self, model_dir=None): + self.word2cj = {} + self.cj2word = {} + self.segmenter = None + self._load_cangjie_mapping(model_dir) + self._init_segmenter() + + def _load_cangjie_mapping(self, model_dir=None): + """Load Cangjie mapping from HuggingFace model repository.""" + try: + cangjie_file = hf_hub_download( + repo_id=REPO_ID, + filename="Cangjie5_TC.json", + cache_dir=model_dir + ) + + with open(cangjie_file, "r", encoding="utf-8") as fp: + data = json.load(fp) + + for entry in data: + word, code = entry.split("\t")[:2] + self.word2cj[word] = code + if code not in self.cj2word: + self.cj2word[code] = [word] + else: + self.cj2word[code].append(word) + + except Exception as e: + logger.warning(f"Could not load Cangjie mapping: {e}") + + def _init_segmenter(self): + """Initialize pkuseg segmenter.""" + try: + from spacy_pkuseg import pkuseg + self.segmenter = pkuseg() + except ImportError: + logger.warning("pkuseg not available - Chinese segmentation will be skipped") + self.segmenter = None + + def _cangjie_encode(self, glyph: str): + """Encode a single Chinese glyph to Cangjie code.""" + normed_glyph = glyph + code = self.word2cj.get(normed_glyph, None) + if code is None: # e.g.
Japanese hiragana + return None + index = self.cj2word[code].index(normed_glyph) + index = str(index) if index > 0 else "" + return code + str(index) + + + + def __call__(self, text): + """Convert Chinese characters in text to Cangjie tokens.""" + output = [] + if self.segmenter is not None: + segmented_words = self.segmenter.cut(text) + full_text = " ".join(segmented_words) + else: + full_text = text + + for t in full_text: + if category(t) == "Lo": + cangjie = self._cangjie_encode(t) + if cangjie is None: + output.append(t) + continue + code = [] + for c in cangjie: + code.append(f"[cj_{c}]") + code.append("[cj_.]") + code = "".join(code) + output.append(code) + else: + output.append(t) + return "".join(output) + + +def add_russian_stress(text: str) -> str: + """Russian text normalization: adds stress marks to Russian text.""" + global _russian_stresser + + try: + if _russian_stresser is None: + from russian_text_stresser.text_stresser import RussianTextStresser + _russian_stresser = RussianTextStresser() + + return _russian_stresser.stress_text(text) + + except ImportError: + logger.warning("russian_text_stresser not available - Russian stress labeling skipped") + return text + except Exception as e: + logger.warning(f"Russian stress labeling failed: {e}") + return text + + +class MTLTokenizer: + def __init__(self, vocab_file_path): + self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path) + model_dir = Path(vocab_file_path).parent + self.cangjie_converter = ChineseCangjieConverter(model_dir) + self.check_vocabset_sot_eot() + + def check_vocabset_sot_eot(self): + voc = self.tokenizer.get_vocab() + assert SOT in voc + assert EOT in voc + + def preprocess_text(self, raw_text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True): + """ + Text preprocessor that handles lowercase conversion and NFKD normalization. 
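+
+ Illustrative example (an assumption for documentation, not behaviour
+ asserted by the upstream code): with both flags on, "Ångström" is
+ lowercased to "ångström", and NFKD then splits each accented letter into
+ a base letter plus a combining mark, which is presumably the decomposed
+ form the grapheme vocabulary expects.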
+ """ + preprocessed_text = raw_text + if lowercase: + preprocessed_text = preprocessed_text.lower() + if nfkd_normalize: + preprocessed_text = normalize("NFKD", preprocessed_text) + + return preprocessed_text + + def text_to_tokens(self, text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True): + text_tokens = self.encode(text, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize) + text_tokens = torch.IntTensor(text_tokens).unsqueeze(0) + return text_tokens + + def encode(self, txt: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True): + txt = self.preprocess_text(txt, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize) + + # Language-specific text processing + if language_id == 'zh': + txt = self.cangjie_converter(txt) + elif language_id == 'ja': + txt = hiragana_normalize(txt) + elif language_id == 'he': + txt = add_hebrew_diacritics(txt) + elif language_id == 'ko': + txt = korean_normalize(txt) + elif language_id == 'ru': + txt = add_russian_stress(txt) + + # Prepend language token + if language_id: + txt = f"[{language_id.lower()}]{txt}" + + txt = txt.replace(' ', SPACE) + return self.tokenizer.encode(txt).ids + + def decode(self, seq): + if isinstance(seq, torch.Tensor): + seq = seq.cpu().numpy() + + txt = self.tokenizer.decode(seq, skip_special_tokens=False) + txt = txt.replace(' ', '').replace(SPACE, ' ').replace(EOT, '').replace(UNK, '') + return txt diff --git a/chatterbox/src/chatterbox/models/utils.py b/chatterbox/src/chatterbox/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a4abce5d072e1a7d3699051130913661a5605af0 --- /dev/null +++ b/chatterbox/src/chatterbox/models/utils.py @@ -0,0 +1,4 @@ +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self diff --git a/chatterbox/src/chatterbox/models/voice_encoder/__pycache__/__init__.cpython-311.pyc b/chatterbox/src/chatterbox/models/voice_encoder/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 682c2c27bf3ff22332f5e82486af4269d4d5e02e..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/voice_encoder/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/voice_encoder/__pycache__/config.cpython-311.pyc b/chatterbox/src/chatterbox/models/voice_encoder/__pycache__/config.cpython-311.pyc deleted file mode 100644 index 605bca3b567e4d575e5f6c37e25bff163d0ce44d..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/voice_encoder/__pycache__/config.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/voice_encoder/__pycache__/melspec.cpython-311.pyc b/chatterbox/src/chatterbox/models/voice_encoder/__pycache__/melspec.cpython-311.pyc deleted file mode 100644 index dac17905b3483798277dd1f6a0fefc22c01199c0..0000000000000000000000000000000000000000 Binary files a/chatterbox/src/chatterbox/models/voice_encoder/__pycache__/melspec.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/models/voice_encoder/__pycache__/voice_encoder.cpython-311.pyc b/chatterbox/src/chatterbox/models/voice_encoder/__pycache__/voice_encoder.cpython-311.pyc deleted file mode 100644 index 827f7723385acabc4e90420e556e8889c0af89db..0000000000000000000000000000000000000000 Binary files 
a/chatterbox/src/chatterbox/models/voice_encoder/__pycache__/voice_encoder.cpython-311.pyc and /dev/null differ diff --git a/chatterbox/src/chatterbox/mtl_tts.py b/chatterbox/src/chatterbox/mtl_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..2c9cf0524f9523cf7c427d85f96ae1254dfcf410 --- /dev/null +++ b/chatterbox/src/chatterbox/mtl_tts.py @@ -0,0 +1,301 @@ +from dataclasses import dataclass +from pathlib import Path +import os + +import librosa +import torch +import perth +import torch.nn.functional as F +from safetensors.torch import load_file as load_safetensors +from huggingface_hub import snapshot_download + +from .models.t3 import T3 +from .models.t3.modules.t3_config import T3Config +from .models.s3tokenizer import S3_SR, drop_invalid_tokens +from .models.s3gen import S3GEN_SR, S3Gen +from .models.tokenizers import MTLTokenizer +from .models.voice_encoder import VoiceEncoder +from .models.t3.modules.cond_enc import T3Cond + + +REPO_ID = "ResembleAI/chatterbox" + +# Supported languages for the multilingual model +SUPPORTED_LANGUAGES = { + "ar": "Arabic", + "da": "Danish", + "de": "German", + "el": "Greek", + "en": "English", + "es": "Spanish", + "fi": "Finnish", + "fr": "French", + "he": "Hebrew", + "hi": "Hindi", + "it": "Italian", + "ja": "Japanese", + "ko": "Korean", + "ms": "Malay", + "nl": "Dutch", + "no": "Norwegian", + "pl": "Polish", + "pt": "Portuguese", + "ru": "Russian", + "sv": "Swedish", + "sw": "Swahili", + "tr": "Turkish", + "zh": "Chinese", +} + + +def punc_norm(text: str) -> str: + """ + Quick cleanup func for punctuation from LLMs or + containing chars not seen often in the dataset + """ + if len(text) == 0: + return "You need to add some text for me to talk." + + # Capitalise first letter + if text[0].islower(): + text = text[0].upper() + text[1:] + + # Remove multiple space chars + text = " ".join(text.split()) + + # Replace uncommon/llm punc + punc_to_replace = [ + ("...", ", "), + ("…", ", "), + (":", ","), + (" - ", ", "), + (";", ", "), + ("—", "-"), + ("–", "-"), + (" ,", ","), + ("“", "\""), + ("”", "\""), + ("‘", "'"), + ("’", "'"), + ] + for old_char_sequence, new_char in punc_to_replace: + text = text.replace(old_char_sequence, new_char) + + # Add full stop if no ending punc + text = text.rstrip(" ") + sentence_enders = {".", "!", "?", "-", ",","、",",","。","?","!"} + if not any(text.endswith(p) for p in sentence_enders): + text += "."
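+ # Worked example (illustrative): punc_norm("hello   world…") capitalises
+ # the first letter, collapses the repeated spaces, rewrites "…" to ", ",
+ # and strips the trailing space, returning "Hello world,". The trailing
+ # comma already counts as a sentence ender, so no "." is appended.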
+ + return text + + +@dataclass +class Conditionals: + """ + Conditionals for T3 and S3Gen + - T3 conditionals: + - speaker_emb + - clap_emb + - cond_prompt_speech_tokens + - cond_prompt_speech_emb + - emotion_adv + - S3Gen conditionals: + - prompt_token + - prompt_token_len + - prompt_feat + - prompt_feat_len + - embedding + """ + t3: T3Cond + gen: dict + + def to(self, device): + self.t3 = self.t3.to(device=device) + for k, v in self.gen.items(): + if torch.is_tensor(v): + self.gen[k] = v.to(device=device) + return self + + def save(self, fpath: Path): + arg_dict = dict( + t3=self.t3.__dict__, + gen=self.gen + ) + torch.save(arg_dict, fpath) + + @classmethod + def load(cls, fpath, map_location="cpu"): + kwargs = torch.load(fpath, map_location=map_location, weights_only=True) + return cls(T3Cond(**kwargs['t3']), kwargs['gen']) + + +class ChatterboxMultilingualTTS: + ENC_COND_LEN = 6 * S3_SR + DEC_COND_LEN = 10 * S3GEN_SR + + def __init__( + self, + t3: T3, + s3gen: S3Gen, + ve: VoiceEncoder, + tokenizer: MTLTokenizer, + device: str, + conds: Conditionals = None, + ): + self.sr = S3GEN_SR # sample rate of synthesized audio + self.t3 = t3 + self.s3gen = s3gen + self.ve = ve + self.tokenizer = tokenizer + self.device = device + self.conds = conds + self.watermarker = perth.PerthImplicitWatermarker() + + @classmethod + def get_supported_languages(cls): + """Return dictionary of supported language codes and names.""" + return SUPPORTED_LANGUAGES.copy() + + @classmethod + def from_local(cls, ckpt_dir, device) -> 'ChatterboxMultilingualTTS': + ckpt_dir = Path(ckpt_dir) + + ve = VoiceEncoder() + ve.load_state_dict( + torch.load(ckpt_dir / "ve.pt", weights_only=True) + ) + ve.to(device).eval() + + t3 = T3(T3Config.multilingual()) + t3_state = load_safetensors(ckpt_dir / "t3_mtl23ls_v2.safetensors") + if "model" in t3_state.keys(): + t3_state = t3_state["model"][0] + t3.load_state_dict(t3_state) + t3.to(device).eval() + + s3gen = S3Gen() + s3gen.load_state_dict( + torch.load(ckpt_dir / "s3gen.pt", weights_only=True) + ) + s3gen.to(device).eval() + + tokenizer = MTLTokenizer( + str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json") + ) + + conds = None + if (builtin_voice := ckpt_dir / "conds.pt").exists(): + conds = Conditionals.load(builtin_voice).to(device) + + return cls(t3, s3gen, ve, tokenizer, device, conds=conds) + + @classmethod + def from_pretrained(cls, device: torch.device) -> 'ChatterboxMultilingualTTS': + ckpt_dir = Path( + snapshot_download( + repo_id=REPO_ID, + repo_type="model", + revision="main", + allow_patterns=["ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt", "grapheme_mtl_merged_expanded_v1.json", "conds.pt", "Cangjie5_TC.json"], + token=os.getenv("HF_TOKEN"), + ) + ) + return cls.from_local(ckpt_dir, device) + + def prepare_conditionals(self, wav_fpath, exaggeration=0.5): + ## Load reference wav + s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR) + + ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR) + + s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN] + s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device) + + # Speech cond prompt tokens + t3_cond_prompt_tokens = None + if plen := self.t3.hp.speech_cond_prompt_len: + s3_tokzr = self.s3gen.tokenizer + t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen) + t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device) + + # Voice-encoder speaker embedding + ve_embed = 
+
+
+class ChatterboxMultilingualTTS:
+    ENC_COND_LEN = 6 * S3_SR
+    DEC_COND_LEN = 10 * S3GEN_SR
+
+    def __init__(
+        self,
+        t3: T3,
+        s3gen: S3Gen,
+        ve: VoiceEncoder,
+        tokenizer: MTLTokenizer,
+        device: str,
+        conds: Conditionals = None,
+    ):
+        self.sr = S3GEN_SR  # sample rate of synthesized audio
+        self.t3 = t3
+        self.s3gen = s3gen
+        self.ve = ve
+        self.tokenizer = tokenizer
+        self.device = device
+        self.conds = conds
+        self.watermarker = perth.PerthImplicitWatermarker()
+
+    @classmethod
+    def get_supported_languages(cls):
+        """Return dictionary of supported language codes and names."""
+        return SUPPORTED_LANGUAGES.copy()
+
+    @classmethod
+    def from_local(cls, ckpt_dir, device) -> 'ChatterboxMultilingualTTS':
+        ckpt_dir = Path(ckpt_dir)
+
+        ve = VoiceEncoder()
+        ve.load_state_dict(
+            torch.load(ckpt_dir / "ve.pt", weights_only=True)
+        )
+        ve.to(device).eval()
+
+        t3 = T3(T3Config.multilingual())
+        t3_state = load_safetensors(ckpt_dir / "t3_mtl23ls_v2.safetensors")
+        if "model" in t3_state.keys():
+            t3_state = t3_state["model"][0]
+        t3.load_state_dict(t3_state)
+        t3.to(device).eval()
+
+        s3gen = S3Gen()
+        s3gen.load_state_dict(
+            torch.load(ckpt_dir / "s3gen.pt", weights_only=True)
+        )
+        s3gen.to(device).eval()
+
+        tokenizer = MTLTokenizer(
+            str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json")
+        )
+
+        conds = None
+        if (builtin_voice := ckpt_dir / "conds.pt").exists():
+            conds = Conditionals.load(builtin_voice).to(device)
+
+        return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
+
+    @classmethod
+    def from_pretrained(cls, device: torch.device) -> 'ChatterboxMultilingualTTS':
+        ckpt_dir = Path(
+            snapshot_download(
+                repo_id=REPO_ID,
+                repo_type="model",
+                revision="main",
+                allow_patterns=["ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt", "grapheme_mtl_merged_expanded_v1.json", "conds.pt", "Cangjie5_TC.json"],
+                token=os.getenv("HF_TOKEN"),
+            )
+        )
+        return cls.from_local(ckpt_dir, device)
+
+    def prepare_conditionals(self, wav_fpath, exaggeration=0.5):
+        ## Load reference wav
+        s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)
+
+        ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
+
+        s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
+        s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
+
+        # Speech cond prompt tokens
+        t3_cond_prompt_tokens = None
+        if plen := self.t3.hp.speech_cond_prompt_len:
+            s3_tokzr = self.s3gen.tokenizer
+            t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen)
+            t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device)
+
+        # Voice-encoder speaker embedding
+        ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR))
+        ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)
+
+        t3_cond = T3Cond(
+            speaker_emb=ve_embed,
+            cond_prompt_speech_tokens=t3_cond_prompt_tokens,
+            emotion_adv=exaggeration * torch.ones(1, 1, 1),
+        ).to(device=self.device)
+        self.conds = Conditionals(t3_cond, s3gen_ref_dict)
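+
+    # Usage sketch (illustrative; the path is a placeholder): the reference
+    # voice is cached on the instance, so it can be prepared once and then
+    # reused across several generate() calls:
+    #   model.prepare_conditionals("reference.wav", exaggeration=0.7)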
+
+    def generate(
+        self,
+        text,
+        language_id,
+        audio_prompt_path=None,
+        exaggeration=0.5,
+        cfg_weight=0.5,
+        temperature=0.8,
+        repetition_penalty=2.0,
+        min_p=0.05,
+        top_p=1.0,
+    ):
+        # Validate language_id
+        if language_id and language_id.lower() not in SUPPORTED_LANGUAGES:
+            supported_langs = ", ".join(SUPPORTED_LANGUAGES.keys())
+            raise ValueError(
+                f"Unsupported language_id '{language_id}'. "
+                f"Supported languages: {supported_langs}"
+            )
+
+        if audio_prompt_path:
+            self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)
+        else:
+            assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"
+
+        # Update exaggeration if needed
+        if float(exaggeration) != float(self.conds.t3.emotion_adv[0, 0, 0].item()):
+            _cond: T3Cond = self.conds.t3
+            self.conds.t3 = T3Cond(
+                speaker_emb=_cond.speaker_emb,
+                cond_prompt_speech_tokens=_cond.cond_prompt_speech_tokens,
+                emotion_adv=exaggeration * torch.ones(1, 1, 1),
+            ).to(device=self.device)
+
+        # Norm and tokenize text
+        text = punc_norm(text)
+        text_tokens = self.tokenizer.text_to_tokens(text, language_id=language_id.lower() if language_id else None).to(self.device)
+        text_tokens = torch.cat([text_tokens, text_tokens], dim=0)  # Need two seqs for CFG
+
+        sot = self.t3.hp.start_text_token
+        eot = self.t3.hp.stop_text_token
+        text_tokens = F.pad(text_tokens, (1, 0), value=sot)
+        text_tokens = F.pad(text_tokens, (0, 1), value=eot)
+
+        with torch.inference_mode():
+            speech_tokens = self.t3.inference(
+                t3_cond=self.conds.t3,
+                text_tokens=text_tokens,
+                max_new_tokens=1000,  # TODO: use the value in config
+                temperature=temperature,
+                cfg_weight=cfg_weight,
+                repetition_penalty=repetition_penalty,
+                min_p=min_p,
+                top_p=top_p,
+            )
+            # Extract only the conditional batch.
+            speech_tokens = speech_tokens[0]
+
+            # TODO: output becomes 1D
+            speech_tokens = drop_invalid_tokens(speech_tokens)
+            speech_tokens = speech_tokens.to(self.device)
+
+            wav, _ = self.s3gen.inference(
+                speech_tokens=speech_tokens,
+                ref_dict=self.conds.gen,
+            )
+            wav = wav.squeeze(0).detach().cpu().numpy()
+        watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
+        return torch.from_numpy(watermarked_wav).unsqueeze(0)
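An end-to-end usage sketch for the new multilingual entry point (illustrative: the device string, reference path, and the torchaudio save step are assumptions, not part of the diff):

    import torchaudio
    from chatterbox.mtl_tts import ChatterboxMultilingualTTS

    model = ChatterboxMultilingualTTS.from_pretrained(device="cuda")
    wav = model.generate(
        "Bonjour tout le monde.",
        language_id="fr",                   # must be a key of SUPPORTED_LANGUAGES
        audio_prompt_path="reference.wav",  # placeholder reference clip
    )
    torchaudio.save("out-fr.wav", wav, model.sr)  # wav is a (1, N) watermarked tensor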
diff --git a/chatterbox/src/chatterbox/tts.py b/chatterbox/src/chatterbox/tts.py
index ae838c21be9c8f2e933ee450658e672bb85d4b5d..6d9b5ad54afb6d661158a0c6f3a5d8f373fa72b8 100644
--- a/chatterbox/src/chatterbox/tts.py
+++ b/chatterbox/src/chatterbox/tts.py
@@ -6,6 +6,7 @@ import torch
 import perth
 import torch.nn.functional as F
 from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
 
 from .models.t3 import T3
 from .models.s3tokenizer import S3_SR, drop_invalid_tokens
@@ -96,6 +97,8 @@ class Conditionals:
 
     @classmethod
     def load(cls, fpath, map_location="cpu"):
+        if isinstance(map_location, str):
+            map_location = torch.device(map_location)
         kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
         return cls(T3Cond(**kwargs['t3']), kwargs['gen'])
 
@@ -126,14 +129,20 @@
     def from_local(cls, ckpt_dir, device) -> 'ChatterboxTTS':
         ckpt_dir = Path(ckpt_dir)
 
+        # Always load to CPU first for non-CUDA devices to handle CUDA-saved models
+        if device in ["cpu", "mps"]:
+            map_location = torch.device('cpu')
+        else:
+            map_location = None
+
         ve = VoiceEncoder()
         ve.load_state_dict(
-            torch.load(ckpt_dir / "ve.pt")
+            load_file(ckpt_dir / "ve.safetensors")
         )
         ve.to(device).eval()
 
         t3 = T3()
-        t3_state = torch.load(ckpt_dir / "t3_cfg.pt")
+        t3_state = load_file(ckpt_dir / "t3_cfg.safetensors")
         if "model" in t3_state.keys():
             t3_state = t3_state["model"][0]
         t3.load_state_dict(t3_state)
@@ -141,7 +150,7 @@ class ChatterboxTTS:
 
         s3gen = S3Gen()
         s3gen.load_state_dict(
-            torch.load(ckpt_dir / "s3gen.pt")
+            load_file(ckpt_dir / "s3gen.safetensors"), strict=False
         )
         s3gen.to(device).eval()
 
@@ -151,13 +160,21 @@ class ChatterboxTTS:
 
         conds = None
         if (builtin_voice := ckpt_dir / "conds.pt").exists():
-            conds = Conditionals.load(builtin_voice).to(device)
+            conds = Conditionals.load(builtin_voice, map_location=map_location).to(device)
 
         return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
 
     @classmethod
     def from_pretrained(cls, device) -> 'ChatterboxTTS':
-        for fpath in ["ve.pt", "t3_cfg.pt", "s3gen.pt", "tokenizer.json", "conds.pt"]:
+        # Check if MPS is available on macOS
+        if device == "mps" and not torch.backends.mps.is_available():
+            if not torch.backends.mps.is_built():
+                print("MPS not available because the current PyTorch install was not built with MPS enabled.")
+            else:
+                print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.")
+            device = "cpu"
+
+        for fpath in ["ve.safetensors", "t3_cfg.safetensors", "s3gen.safetensors", "tokenizer.json", "conds.pt"]:
             local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath)
 
         return cls.from_local(Path(local_path).parent, device)
@@ -191,6 +208,9 @@ class ChatterboxTTS:
     def generate(
         self,
         text,
+        repetition_penalty=1.2,
+        min_p=0.05,
+        top_p=1.0,
         audio_prompt_path=None,
         exaggeration=0.5,
         cfg_weight=0.5,
@@ -213,7 +233,9 @@
         # Norm and tokenize text
         text = punc_norm(text)
         text_tokens = self.tokenizer.text_to_tokens(text).to(self.device)
-        text_tokens = torch.cat([text_tokens, text_tokens], dim=0)  # Need two seqs for CFG
+
+        if cfg_weight > 0.0:
+            text_tokens = torch.cat([text_tokens, text_tokens], dim=0)  # Need two seqs for CFG
 
         sot = self.t3.hp.start_text_token
         eot = self.t3.hp.stop_text_token
@@ -227,12 +249,18 @@
                 max_new_tokens=1000,  # TODO: use the value in config
                 temperature=temperature,
                 cfg_weight=cfg_weight,
+                repetition_penalty=repetition_penalty,
+                min_p=min_p,
+                top_p=top_p,
             )
             # Extract only the conditional batch.
             speech_tokens = speech_tokens[0]
 
             # TODO: output becomes 1D
             speech_tokens = drop_invalid_tokens(speech_tokens)
+
+            speech_tokens = speech_tokens[speech_tokens < 6561]
+
             speech_tokens = speech_tokens.to(self.device)
 
             wav, _ = self.s3gen.inference(
@@ -241,4 +269,4 @@
             )
             wav = wav.squeeze(0).detach().cpu().numpy()
             watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
-        return torch.from_numpy(watermarked_wav).unsqueeze(0)
+        return torch.from_numpy(watermarked_wav).unsqueeze(0)
\ No newline at end of file
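A sketch of the reworked ChatterboxTTS.generate call with the new sampling knobs (values shown are the defaults introduced above; the text and device are placeholders):

    from chatterbox.tts import ChatterboxTTS

    model = ChatterboxTTS.from_pretrained(device="cpu")
    wav = model.generate(
        "Hello there!",
        repetition_penalty=1.2,  # >1.0 penalises repeated speech tokens
        min_p=0.05,              # drops tokens below 5% of the top token's probability
        top_p=1.0,               # 1.0 leaves nucleus filtering effectively off
        cfg_weight=0.5,          # 0.0 now skips the CFG sequence duplication entirely
    )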
diff --git a/chatterbox/src/chatterbox/utils.py b/chatterbox/src/chatterbox/utils.py
deleted file mode 100644
index 6b06428049254576d7b0281beac3daf3bfdec515..0000000000000000000000000000000000000000
--- a/chatterbox/src/chatterbox/utils.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# import numpy as np
-#
-#
-# def trim_silence(wav, speech_timestamps, sr):
-#     """TODO: fading"""
-#     if len(speech_timestamps) == 0:
-#         return wav  # WARNING: no speech detected, returning original wav
-#     segs = []
-#     for segment in speech_timestamps:
-#         start_s, end_s = segment['start'], segment['end']
-#         start = int(start_s * sr)
-#         end = int(end_s * sr)
-#         seg = wav[start: end]
-#         segs.append(seg)
-#     return np.concatenate(segs)
diff --git a/chatterbox/src/chatterbox/vc.py b/chatterbox/src/chatterbox/vc.py
index 7fa00bfbfef4791a69b58bcb75c29d5aeebd3cd3..a9c32ed3567192f07eee68a78c0517c1324892fc 100644
--- a/chatterbox/src/chatterbox/vc.py
+++ b/chatterbox/src/chatterbox/vc.py
@@ -4,6 +4,7 @@ import librosa
 import torch
 import perth
 from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
 
 from .models.s3tokenizer import S3_SR
 from .models.s3gen import S3GEN_SR, S3Gen
@@ -37,14 +38,21 @@ class ChatterboxVC:
     @classmethod
     def from_local(cls, ckpt_dir, device) -> 'ChatterboxVC':
         ckpt_dir = Path(ckpt_dir)
+
+        # Always load to CPU first for non-CUDA devices to handle CUDA-saved models
+        if device in ["cpu", "mps"]:
+            map_location = torch.device('cpu')
+        else:
+            map_location = None
+
         ref_dict = None
         if (builtin_voice := ckpt_dir / "conds.pt").exists():
-            states = torch.load(builtin_voice)
+            states = torch.load(builtin_voice, map_location=map_location)
             ref_dict = states['gen']
 
         s3gen = S3Gen()
         s3gen.load_state_dict(
-            torch.load(ckpt_dir / "s3gen.pt")
+            load_file(ckpt_dir / "s3gen.safetensors"), strict=False
         )
         s3gen.to(device).eval()
 
@@ -52,7 +60,15 @@ class ChatterboxVC:
     @classmethod
     def from_pretrained(cls, device) -> 'ChatterboxVC':
-        for fpath in ["s3gen.pt", "conds.pt"]:
+        # Check if MPS is available on macOS
+        if device == "mps" and not torch.backends.mps.is_available():
+            if not torch.backends.mps.is_built():
+                print("MPS not available because the current PyTorch install was not built with MPS enabled.")
+            else:
+                print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.")
+            device = "cpu"
+
+        for fpath in ["s3gen.safetensors", "conds.pt"]:
             local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath)
 
         return cls.from_local(Path(local_path).parent, device)
@@ -85,4 +101,4 @@ class ChatterboxVC:
         )
         wav = wav.squeeze(0).detach().cpu().numpy()
         watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
-        return torch.from_numpy(watermarked_wav).unsqueeze(0)
+        return torch.from_numpy(watermarked_wav).unsqueeze(0)
\ No newline at end of file
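A minimal sketch of the voice-conversion loader with the new device handling (illustrative; only the device string is supplied by the caller):

    from chatterbox.vc import ChatterboxVC

    # Requesting "mps" now degrades gracefully to CPU when MPS is unavailable.
    vc = ChatterboxVC.from_pretrained(device="mps")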
diff --git a/current_structure.txt b/current_structure.txt
new file mode 100644
index 0000000000000000000000000000000000000000..55690ddabc1b78c979d887afb51c1218ca632e1d
--- /dev/null
+++ b/current_structure.txt
@@ -0,0 +1,158 @@
+chatterbox/src/chatterbox:
+__init__.py
+models
+__pycache__
+tts.py
+utils.py
+vc.py
+
+chatterbox/src/chatterbox/models:
+s3gen
+s3tokenizer
+t3
+tokenizers
+voice_encoder
+
+chatterbox/src/chatterbox/models/s3gen:
+const.py
+decoder.py
+f0_predictor.py
+flow_matching.py
+flow.py
+hifigan.py
+__init__.py
+matcha
+__pycache__
+s3gen.py
+transformer
+utils
+xvector.py
+
+chatterbox/src/chatterbox/models/s3gen/matcha:
+decoder.py
+flow_matching.py
+__pycache__
+text_encoder.py
+transformer.py
+
+chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__:
+decoder.cpython-311.pyc
+flow_matching.cpython-311.pyc
+transformer.cpython-311.pyc
+
+chatterbox/src/chatterbox/models/s3gen/__pycache__:
+const.cpython-311.pyc
+decoder.cpython-311.pyc
+f0_predictor.cpython-311.pyc
+flow.cpython-311.pyc
+flow_matching.cpython-311.pyc
+hifigan.cpython-311.pyc
+__init__.cpython-311.pyc
+s3gen.cpython-311.pyc
+xvector.cpython-311.pyc
+
+chatterbox/src/chatterbox/models/s3gen/transformer:
+activation.py
+attention.py
+convolution.py
+embedding.py
+encoder_layer.py
+__init__.py
+positionwise_feed_forward.py
+__pycache__
+subsampling.py
+upsample_encoder.py
+
+chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__:
+activation.cpython-311.pyc
+attention.cpython-311.pyc
+convolution.cpython-311.pyc
+embedding.cpython-311.pyc
+encoder_layer.cpython-311.pyc
+__init__.cpython-311.pyc
+positionwise_feed_forward.cpython-311.pyc
+subsampling.cpython-311.pyc
+upsample_encoder.cpython-311.pyc
+
+chatterbox/src/chatterbox/models/s3gen/utils:
+class_utils.py
+mask.py
+mel.py
+__pycache__
+
+chatterbox/src/chatterbox/models/s3gen/utils/__pycache__:
+class_utils.cpython-311.pyc
+mask.cpython-311.pyc
+mel.cpython-311.pyc
+
+chatterbox/src/chatterbox/models/s3tokenizer:
+__init__.py
+__pycache__
+s3tokenizer.py
+
+chatterbox/src/chatterbox/models/s3tokenizer/__pycache__:
+__init__.cpython-311.pyc
+s3tokenizer.cpython-311.pyc
+
+chatterbox/src/chatterbox/models/t3:
+inference
+__init__.py
+llama_configs.py
+modules
+__pycache__
+t3.py
+
+chatterbox/src/chatterbox/models/t3/inference:
+alignment_stream_analyzer.py
+__pycache__
+t3_hf_backend.py
+
+chatterbox/src/chatterbox/models/t3/inference/__pycache__:
+alignment_stream_analyzer.cpython-311.pyc
+t3_hf_backend.cpython-311.pyc
+
+chatterbox/src/chatterbox/models/t3/modules:
+cond_enc.py
+learned_pos_emb.py
+perceiver.py
+__pycache__
+t3_config.py
+
+chatterbox/src/chatterbox/models/t3/modules/__pycache__:
+cond_enc.cpython-311.pyc
+learned_pos_emb.cpython-311.pyc
+perceiver.cpython-311.pyc
+t3_config.cpython-311.pyc
+
+chatterbox/src/chatterbox/models/t3/__pycache__:
+__init__.cpython-311.pyc
+llama_configs.cpython-311.pyc
+t3.cpython-311.pyc
+
+chatterbox/src/chatterbox/models/tokenizers:
+__init__.py
+__pycache__
+tokenizer.py
+
+chatterbox/src/chatterbox/models/tokenizers/__pycache__:
+__init__.cpython-311.pyc
+tokenizer.cpython-311.pyc
+
+chatterbox/src/chatterbox/models/voice_encoder:
+config.py
+__init__.py
+melspec.py
+__pycache__
+voice_encoder.py
+
+chatterbox/src/chatterbox/models/voice_encoder/__pycache__:
+config.cpython-311.pyc
+__init__.cpython-311.pyc
+melspec.cpython-311.pyc
+voice_encoder.cpython-311.pyc
+
+chatterbox/src/chatterbox/__pycache__:
+__init__.cpython-311.pyc
+tts.cpython-311.pyc
+utils.cpython-311.pyc
+vc.cpython-311.pyc
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..035f50aafee6373ad37d722c6500473ec9abd9fc
--- /dev/null
+++ b/main.py
@@ -0,0 +1,6 @@
+def main():
+    print("Hello from ta-chatterbox!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..acc6ebebf501d64f99de8cad6ee60650306df2d6
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,7 @@
+[project]
+name = "ta-chatterbox"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.10"
+dependencies = []
diff --git a/requirements.txt b/requirements.txt
index afcc82ea217279792dfe236ab3d58d79b9d6c942..6f96856f1622b52a2c780c8e987216ee4ded8fac 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 gradio
-numpy==1.26.0
+numpy<1.26.0
 resampy==0.4.3
-librosa==0.10.0
+librosa==0.11.0
 s3tokenizer
 transformers==4.46.3
@@ -9,4 +9,8 @@ diffusers==0.29.0
 omegaconf==2.3.0
 resemble-perth==1.0.1
 silero-vad==5.1.2
-conformer==0.3.2
\ No newline at end of file
+conformer==0.3.2
+safetensors==0.5.3
+spacy-pkuseg
+pykakasi==2.3.0
+git+https://github.com/Vuizur/add-stress-to-epub
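Setup sketch (illustrative): with the updated pins above (numpy capped below 1.26, librosa bumped to 0.11.0, safetensors added), a fresh environment is installed with:

    pip install -r requirements.txt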