j committed
Commit f98c92f · 1 Parent(s): af25078

update to upstream chatterbox implementation, fixes token filtering/clamping
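In short: speech-token IDs are now clamped into the embedding table's valid range (and logged when any fall outside it) before lookup, instead of only being clamped at zero. A minimal sketch of the pattern, with made-up sizes (the real code uses self.input_embedding inside flow.py):

import torch
import torch.nn as nn

vocab_size, dim = 6561, 512                    # illustrative sizes, not the model's actual config
input_embedding = nn.Embedding(vocab_size, dim)

token = torch.tensor([[12, 6560, 6600, -1]])   # last two IDs are out of range
if token.max() >= vocab_size or token.min() < 0:
    print(f"Token IDs out of bounds: min={token.min().item()}, max={token.max().item()}")
token = torch.clamp(token, min=0, max=input_embedding.num_embeddings - 1)
emb = input_embedding(token)                   # safe lookup, no index error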

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. README.md +1 -0
  2. chatterbox/src/chatterbox/__init__.py +15 -0
  3. chatterbox/src/chatterbox/__pycache__/__init__.cpython-311.pyc +0 -0
  4. chatterbox/src/chatterbox/__pycache__/tts.cpython-311.pyc +0 -0
  5. chatterbox/src/chatterbox/__pycache__/utils.cpython-311.pyc +0 -0
  6. chatterbox/src/chatterbox/__pycache__/vc.cpython-311.pyc +0 -0
  7. chatterbox/src/chatterbox/models/__init__.py +0 -0
  8. chatterbox/src/chatterbox/models/s3gen/__pycache__/__init__.cpython-311.pyc +0 -0
  9. chatterbox/src/chatterbox/models/s3gen/__pycache__/const.cpython-311.pyc +0 -0
  10. chatterbox/src/chatterbox/models/s3gen/__pycache__/decoder.cpython-311.pyc +0 -0
  11. chatterbox/src/chatterbox/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc +0 -0
  12. chatterbox/src/chatterbox/models/s3gen/__pycache__/flow.cpython-311.pyc +0 -0
  13. chatterbox/src/chatterbox/models/s3gen/__pycache__/flow_matching.cpython-311.pyc +0 -0
  14. chatterbox/src/chatterbox/models/s3gen/__pycache__/hifigan.cpython-311.pyc +0 -0
  15. chatterbox/src/chatterbox/models/s3gen/__pycache__/s3gen.cpython-311.pyc +0 -0
  16. chatterbox/src/chatterbox/models/s3gen/__pycache__/xvector.cpython-311.pyc +0 -0
  17. chatterbox/src/chatterbox/models/s3gen/configs.py +10 -0
  18. chatterbox/src/chatterbox/models/s3gen/flow.py +89 -41
  19. chatterbox/src/chatterbox/models/s3gen/flow_matching.py +1 -11
  20. chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc +0 -0
  21. chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc +0 -0
  22. chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc +0 -0
  23. chatterbox/src/chatterbox/models/s3gen/s3gen.py +2 -9
  24. chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc +0 -0
  25. chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc +0 -0
  26. chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc +0 -0
  27. chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc +0 -0
  28. chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc +0 -0
  29. chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc +0 -0
  30. chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc +0 -0
  31. chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc +0 -0
  32. chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc +0 -0
  33. chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc +0 -0
  34. chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/mask.cpython-311.pyc +0 -0
  35. chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/mel.cpython-311.pyc +0 -0
  36. chatterbox/src/chatterbox/models/s3gen/utils/mel.py +8 -4
  37. chatterbox/src/chatterbox/models/s3tokenizer/__pycache__/__init__.cpython-311.pyc +0 -0
  38. chatterbox/src/chatterbox/models/s3tokenizer/__pycache__/s3tokenizer.cpython-311.pyc +0 -0
  39. chatterbox/src/chatterbox/models/t3/__pycache__/__init__.cpython-311.pyc +0 -0
  40. chatterbox/src/chatterbox/models/t3/__pycache__/llama_configs.cpython-311.pyc +0 -0
  41. chatterbox/src/chatterbox/models/t3/__pycache__/t3.cpython-311.pyc +0 -0
  42. chatterbox/src/chatterbox/models/t3/inference/__pycache__/alignment_stream_analyzer.cpython-311.pyc +0 -0
  43. chatterbox/src/chatterbox/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc +0 -0
  44. chatterbox/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py +56 -32
  45. chatterbox/src/chatterbox/models/t3/modules/__pycache__/cond_enc.cpython-311.pyc +0 -0
  46. chatterbox/src/chatterbox/models/t3/modules/__pycache__/learned_pos_emb.cpython-311.pyc +0 -0
  47. chatterbox/src/chatterbox/models/t3/modules/__pycache__/perceiver.cpython-311.pyc +0 -0
  48. chatterbox/src/chatterbox/models/t3/modules/__pycache__/t3_config.cpython-311.pyc +0 -0
  49. chatterbox/src/chatterbox/models/t3/modules/t3_config.py +30 -16
  50. chatterbox/src/chatterbox/models/t3/t3.py +56 -34
README.md CHANGED
@@ -8,6 +8,7 @@ sdk_version: 5.29.0
 app_file: app.py
 pinned: false
 short_description: Expressive Zeroshot TTS
+python_version: 3.10
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
chatterbox/src/chatterbox/__init__.py CHANGED
@@ -1,2 +1,17 @@
+try:
+    from importlib.metadata import version, PackageNotFoundError
+    try:
+        __version__ = version("chatterbox-tts")
+    except PackageNotFoundError:
+        __version__ = "0.1.4"  # Default fallback version
+except ImportError:
+    from importlib_metadata import version, PackageNotFoundError  # For Python <3.8
+    try:
+        __version__ = version("chatterbox-tts")
+    except PackageNotFoundError:
+        __version__ = "0.1.4"
+
+
 from .tts import ChatterboxTTS
 from .vc import ChatterboxVC
+from .mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
chatterbox/src/chatterbox/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (275 Bytes)
 
chatterbox/src/chatterbox/__pycache__/tts.cpython-311.pyc DELETED
Binary file (13.3 kB)
 
chatterbox/src/chatterbox/__pycache__/utils.cpython-311.pyc DELETED
Binary file (858 Bytes)
 
chatterbox/src/chatterbox/__pycache__/vc.cpython-311.pyc DELETED
Binary file (5.44 kB)
 
chatterbox/src/chatterbox/models/__init__.py ADDED
File without changes
chatterbox/src/chatterbox/models/s3gen/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (294 Bytes)
 
chatterbox/src/chatterbox/models/s3gen/__pycache__/const.cpython-311.pyc DELETED
Binary file (190 Bytes)
 
chatterbox/src/chatterbox/models/s3gen/__pycache__/decoder.cpython-311.pyc DELETED
Binary file (16.9 kB)
 
chatterbox/src/chatterbox/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc DELETED
Binary file (2.7 kB)
 
chatterbox/src/chatterbox/models/s3gen/__pycache__/flow.cpython-311.pyc DELETED
Binary file (13.7 kB)
 
chatterbox/src/chatterbox/models/s3gen/__pycache__/flow_matching.cpython-311.pyc DELETED
Binary file (13.3 kB)
 
chatterbox/src/chatterbox/models/s3gen/__pycache__/hifigan.cpython-311.pyc DELETED
Binary file (26.3 kB)
 
chatterbox/src/chatterbox/models/s3gen/__pycache__/s3gen.cpython-311.pyc DELETED
Binary file (13.7 kB)
 
chatterbox/src/chatterbox/models/s3gen/__pycache__/xvector.cpython-311.pyc DELETED
Binary file (24 kB)
 
chatterbox/src/chatterbox/models/s3gen/configs.py ADDED
@@ -0,0 +1,10 @@
+from ..utils import AttrDict
+
+CFM_PARAMS = AttrDict({
+    "sigma_min": 1e-06,
+    "solver": "euler",
+    "t_scheduler": "cosine",
+    "training_cfg_rate": 0.2,
+    "inference_cfg_rate": 0.7,
+    "reg_loss_type": "l1"
+})
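CFM_PARAMS replaces the former omegaconf DictConfig with the project's AttrDict so the omegaconf dependency can be dropped. For reference, an attribute-style dict in this spirit (mirroring the helper previously inlined in t3.py; the actual chatterbox.utils.AttrDict may differ):

class AttrDict(dict):
    """dict whose keys are also readable as attributes, e.g. params.sigma_min."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self

params = AttrDict({"sigma_min": 1e-06, "solver": "euler"})
assert params.sigma_min == params["sigma_min"]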
chatterbox/src/chatterbox/models/s3gen/flow.py CHANGED
@@ -14,32 +14,54 @@
 import logging
 import random
 from typing import Dict, Optional
+
+logger = logging.getLogger(__name__)
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
-from omegaconf import DictConfig
 from .utils.mask import make_pad_mask
+from .configs import CFM_PARAMS


 class MaskedDiffWithXvec(torch.nn.Module):
-    def __init__(self,
-                 input_size: int = 512,
-                 output_size: int = 80,
-                 spk_embed_dim: int = 192,
-                 output_type: str = "mel",
-                 vocab_size: int = 4096,
-                 input_frame_rate: int = 50,
-                 only_mask_loss: bool = True,
-                 encoder: torch.nn.Module = None,
-                 length_regulator: torch.nn.Module = None,
-                 decoder: torch.nn.Module = None,
-                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
-                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
-                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
-                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
-                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
-                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
-                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+    def __init__(
+        self,
+        input_size: int = 512,
+        output_size: int = 80,
+        spk_embed_dim: int = 192,
+        output_type: str = "mel",
+        vocab_size: int = 4096,
+        input_frame_rate: int = 50,
+        only_mask_loss: bool = True,
+        encoder: torch.nn.Module = None,
+        length_regulator: torch.nn.Module = None,
+        decoder: torch.nn.Module = None,
+        decoder_conf: Dict = {
+            'in_channels': 240,
+            'out_channel': 80,
+            'spk_emb_dim': 80,
+            'n_spks': 1,
+            'cfm_params': CFM_PARAMS,
+            'decoder_params': {
+                'channels': [256, 256],
+                'dropout': 0.0,
+                'attention_head_dim': 64,
+                'n_blocks': 4,
+                'num_mid_blocks': 12,
+                'num_heads': 8,
+                'act_fn': 'gelu',
+            }
+        },
+        mel_feat_conf: Dict = {
+            'n_fft': 1024,
+            'num_mels': 80,
+            'sampling_rate': 22050,
+            'hop_size': 256,
+            'win_size': 1024,
+            'fmin': 0,
+            'fmax': 8000
+        }
+    ):
         super().__init__()
         self.input_size = input_size
         self.output_size = output_size
@@ -74,7 +96,7 @@ class MaskedDiffWithXvec(torch.nn.Module):

         # concat text and prompt_text
         mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
-        token = self.input_embedding(torch.clamp(token, min=0)) * mask
+        token = self.input_embedding(torch.clamp(token, min=0, max=self.input_embedding.num_embeddings-1)) * mask

         # text encode
         h, h_lengths = self.encoder(token, token_len)
@@ -124,7 +146,13 @@ class MaskedDiffWithXvec(torch.nn.Module):
         token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
         token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
         mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
-        token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+        # Check for out-of-bounds token IDs
+        vocab_size = self.input_embedding.num_embeddings
+        if token.max() >= vocab_size or token.min() < 0:
+            logging.warning(f"S3Gen: Token IDs out of bounds: min={token.min().item()}, max={token.max().item()}, vocab_size={vocab_size}")
+
+        token = self.input_embedding(torch.clamp(token, min=0, max=vocab_size-1)) * mask

         # text encode
         h, h_lengths = self.encoder(token, token_len)
@@ -153,25 +181,45 @@


 class CausalMaskedDiffWithXvec(torch.nn.Module):
-    def __init__(self,
-                 input_size: int = 512,
-                 output_size: int = 80,
-                 spk_embed_dim: int = 192,
-                 output_type: str = "mel",
-                 vocab_size: int = 6561,
-                 input_frame_rate: int = 25,
-                 only_mask_loss: bool = True,
-                 token_mel_ratio: int = 2,
-                 pre_lookahead_len: int = 3,
-                 encoder: torch.nn.Module = None,
-                 decoder: torch.nn.Module = None,
-                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
-                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
-                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
-                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
-                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
-                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
-                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+    def __init__(
+        self,
+        input_size: int = 512,
+        output_size: int = 80,
+        spk_embed_dim: int = 192,
+        output_type: str = "mel",
+        vocab_size: int = 6561,
+        input_frame_rate: int = 25,
+        only_mask_loss: bool = True,
+        token_mel_ratio: int = 2,
+        pre_lookahead_len: int = 3,
+        encoder: torch.nn.Module = None,
+        decoder: torch.nn.Module = None,
+        decoder_conf: Dict = {
+            'in_channels': 240,
+            'out_channel': 80,
+            'spk_emb_dim': 80,
+            'n_spks': 1,
+            'cfm_params': CFM_PARAMS,
+            'decoder_params': {
+                'channels': [256, 256],
+                'dropout': 0.0,
+                'attention_head_dim': 64,
+                'n_blocks': 4,
+                'num_mid_blocks': 12,
+                'num_heads': 8,
+                'act_fn': 'gelu',
+            }
+        },
+        mel_feat_conf: Dict = {
+            'n_fft': 1024,
+            'num_mels': 80,
+            'sampling_rate': 22050,
+            'hop_size': 256,
+            'win_size': 1024,
+            'fmin': 0,
+            'fmax': 8000
+        }
+    ):
         super().__init__()
         self.input_size = input_size
         self.output_size = output_size
@@ -215,7 +263,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         # concat text and prompt_text
         token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
         mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
-        token = self.input_embedding(torch.clamp(token, min=0)) * mask
+        token = self.input_embedding(torch.clamp(token, min=0, max=self.input_embedding.num_embeddings-1)) * mask

         # text encode
         h, h_lengths = self.encoder(token, token_len)
chatterbox/src/chatterbox/models/s3gen/flow_matching.py CHANGED
@@ -15,17 +15,7 @@ import threading
 import torch
 import torch.nn.functional as F
 from .matcha.flow_matching import BASECFM
-from omegaconf import OmegaConf
-
-
-CFM_PARAMS = OmegaConf.create({
-    "sigma_min": 1e-06,
-    "solver": "euler",
-    "t_scheduler": "cosine",
-    "training_cfg_rate": 0.2,
-    "inference_cfg_rate": 0.7,
-    "reg_loss_type": "l1"
-})
+from .configs import CFM_PARAMS


 class ConditionalCFM(BASECFM):
chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc DELETED
Binary file (21.3 kB)
 
chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc DELETED
Binary file (6.46 kB)
 
chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc DELETED
Binary file (14.7 kB)
 
chatterbox/src/chatterbox/models/s3gen/s3gen.py CHANGED
@@ -19,7 +19,6 @@ import torch
 import torchaudio as ta
 from functools import lru_cache
 from typing import Optional
-from omegaconf import DictConfig

 from ..s3tokenizer import S3_SR, SPEECH_VOCAB_SIZE, S3Tokenizer
 from .const import S3GEN_SR
@@ -31,6 +30,7 @@ from .hifigan import HiFTGenerator
 from .transformer.upsample_encoder import UpsampleConformerEncoder
 from .flow_matching import CausalConditionalCFM
 from .decoder import ConditionalDecoder
+from .configs import CFM_PARAMS


 def drop_invalid_tokens(x):
@@ -85,14 +85,7 @@ class S3Token2Mel(torch.nn.Module):
             num_heads=8,
             act_fn='gelu',
         )
-        cfm_params = DictConfig({
-            "sigma_min": 1e-06,
-            "solver": 'euler',
-            "t_scheduler": 'cosine',
-            "training_cfg_rate": 0.2,
-            "inference_cfg_rate": 0.7,
-            "reg_loss_type": 'l1',
-        })
+        cfm_params = CFM_PARAMS
         decoder = CausalConditionalCFM(
             spk_emb_dim=80,
             cfm_params=cfm_params,
chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (190 Bytes)
 
chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc DELETED
Binary file (3.58 kB)
 
chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc DELETED
Binary file (15.7 kB)
 
chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc DELETED
Binary file (5.54 kB)
 
chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc DELETED
Binary file (17.3 kB)
 
chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc DELETED
Binary file (11.2 kB)
 
chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc DELETED
Binary file (6.24 kB)
 
chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc DELETED
Binary file (18.9 kB)
 
chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc DELETED
Binary file (15.6 kB)
 
chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc DELETED
Binary file (1.93 kB)
 
chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/mask.cpython-311.pyc DELETED
Binary file (6.25 kB)
 
chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/mel.cpython-311.pyc DELETED
Binary file (4.05 kB)
 
chatterbox/src/chatterbox/models/s3gen/utils/mel.py CHANGED
@@ -1,8 +1,11 @@
 """mel-spectrogram extraction in Matcha-TTS"""
+import logging
 from librosa.filters import mel as librosa_mel_fn
 import torch
 import numpy as np

+logger = logging.getLogger(__name__)
+

 # NOTE: they decalred these global vars
 mel_basis = {}
@@ -42,10 +45,11 @@ def mel_spectrogram(y, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=48
     if len(y.shape) == 1:
         y = y[None, ]

-    if torch.min(y) < -1.0:
-        print("min value is ", torch.min(y))
-    if torch.max(y) > 1.0:
-        print("max value is ", torch.max(y))
+    # Debug: Check for audio clipping (values outside [-1.0, 1.0] range)
+    min_val = torch.min(y)
+    max_val = torch.max(y)
+    if min_val < -1.0 or max_val > 1.0:
+        logger.warning(f"Audio values outside normalized range: min={min_val.item():.4f}, max={max_val.item():.4f}")

     global mel_basis, hann_window  # pylint: disable=global-statement,global-variable-not-assigned
     if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
chatterbox/src/chatterbox/models/s3tokenizer/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (1.37 kB)
 
chatterbox/src/chatterbox/models/s3tokenizer/__pycache__/s3tokenizer.cpython-311.pyc DELETED
Binary file (7.94 kB)
 
chatterbox/src/chatterbox/models/t3/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (218 Bytes)
 
chatterbox/src/chatterbox/models/t3/__pycache__/llama_configs.cpython-311.pyc DELETED
Binary file (1.34 kB)
 
chatterbox/src/chatterbox/models/t3/__pycache__/t3.cpython-311.pyc DELETED
Binary file (15.8 kB)
 
chatterbox/src/chatterbox/models/t3/inference/__pycache__/alignment_stream_analyzer.cpython-311.pyc DELETED
Binary file (7.08 kB)
 
chatterbox/src/chatterbox/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc DELETED
Binary file (4.65 kB)
 
chatterbox/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py CHANGED
@@ -10,6 +10,9 @@ from types import MethodType
 logger = logging.getLogger(__name__)


+LLAMA_ALIGNED_HEADS = [(12, 15), (13, 11), (9, 2)]
+
+
 @dataclass
 class AlignmentAnalysisResult:
     # was this frame detected as being part of a noisy beginning chunk with potential hallucinations?
@@ -49,21 +52,22 @@ class AlignmentStreamAnalyzer:

         self.complete = False
         self.completed_at = None
+
+        # Track generated tokens for repetition detection
+        self.generated_tokens = []

         # Using `output_attentions=True` is incompatible with optimized attention kernels, so
         # using it for all layers slows things down too much. We can apply it to just one layer
         # by intercepting the kwargs and adding a forward hook (credit: jrm)
-        self.last_aligned_attn = None
-        self._add_attention_spy(tfmr, alignment_layer_idx)
+        self.last_aligned_attns = []
+        for i, (layer_idx, head_idx) in enumerate(LLAMA_ALIGNED_HEADS):
+            self.last_aligned_attns += [None]
+            self._add_attention_spy(tfmr, i, layer_idx, head_idx)

-    def _add_attention_spy(self, tfmr, alignment_layer_idx):
+    def _add_attention_spy(self, tfmr, buffer_idx, layer_idx, head_idx):
         """
         Adds a forward hook to a specific attention layer to collect outputs.
-        Using `output_attentions=True` is incompatible with optimized attention kernels, so
-        using it for all layers slows things down too much.
-        (credit: jrm)
         """
-
         def attention_forward_hook(module, input, output):
             """
             See `LlamaAttention.forward`; the output is a 3-tuple: `attn_output, attn_weights, past_key_value`.
@@ -71,27 +75,23 @@
             - When `output_attentions=True`, `LlamaSdpaAttention.forward` calls `LlamaAttention.forward`.
             - `attn_output` has shape [B, H, T0, T0] for the 0th entry, and [B, H, 1, T0+i] for the rest i-th.
             """
-            step_attention = output[1].cpu()  # (B, 16, N, N)
-            self.last_aligned_attn = step_attention[0].mean(0)  # (N, N)
-
-        target_layer = tfmr.layers[alignment_layer_idx].self_attn
-        hook_handle = target_layer.register_forward_hook(attention_forward_hook)
-
-        # Backup original forward
-        original_forward = target_layer.forward
-        def patched_forward(self, *args, **kwargs):
-            kwargs['output_attentions'] = True
-            return original_forward(*args, **kwargs)
-
-        # TODO: how to unpatch it?
-        target_layer.forward = MethodType(patched_forward, target_layer)
-
-    def step(self, logits):
+            if isinstance(output, tuple) and len(output) > 1 and output[1] is not None:
+                step_attention = output[1].cpu()  # (B, n_heads, T0, Ti)
+                self.last_aligned_attns[buffer_idx] = step_attention[0, head_idx]  # (T0, Ti)
+
+        target_layer = tfmr.layers[layer_idx].self_attn
+        # Register hook and store the handle
+        target_layer.register_forward_hook(attention_forward_hook)
+        if hasattr(tfmr, 'config') and hasattr(tfmr.config, 'output_attentions'):
+            self.original_output_attentions = tfmr.config.output_attentions
+            tfmr.config.output_attentions = True
+
+    def step(self, logits, next_token=None):
         """
         Emits an AlignmentAnalysisResult into the output queue, and potentially modifies the logits to force an EOS.
         """
         # extract approximate alignment matrix chunk (1 frame at a time after the first chunk)
-        aligned_attn = self.last_aligned_attn  # (N, N)
+        aligned_attn = torch.stack(self.last_aligned_attns).mean(dim=0)  # (N, N)
         i, j = self.text_tokens_slice
         if self.curr_frame_pos == 0:
             # first chunk has conditioning info, text tokens, and BOS token
@@ -133,22 +133,46 @@
         last_text_token_duration = A[15:, -3:].sum()

         # Activations for the final token that last too long are likely hallucinations.
-        long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 10)  # 400ms
+        long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 5)  # 200ms

         # If there are activations in previous tokens after generation has completed, assume this is a repetition error.
-        repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5)
+        alignment_repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5)
+
+        # Track generated tokens for repetition detection
+        if next_token is not None:
+            # Convert tensor to scalar if needed
+            if isinstance(next_token, torch.Tensor):
+                token_id = next_token.item() if next_token.numel() == 1 else next_token.view(-1)[0].item()
+            else:
+                token_id = next_token
+            self.generated_tokens.append(token_id)
+
+            # Keep only last 8 tokens to prevent memory issues
+            if len(self.generated_tokens) > 8:
+                self.generated_tokens = self.generated_tokens[-8:]
+
+        # Check for excessive token repetition (3x same token in a row)
+        token_repetition = (
+            # self.complete and
+            len(self.generated_tokens) >= 3 and
+            len(set(self.generated_tokens[-2:])) == 1
+        )
+
+        if token_repetition:
+            repeated_token = self.generated_tokens[-1]
+            logger.warning(f"🚨 Detected 2x repetition of token {repeated_token}")
+
+        # Suppress EoS to prevent early termination
+        if cur_text_posn < S - 3 and S > 5:  # Only suppress if text is longer than 5 tokens
+            logits[..., self.eos_idx] = -2**15

         # If a bad ending is detected, force emit EOS by modifying logits
         # NOTE: this means logits may be inconsistent with latents!
-        if long_tail or repetition:
-            logger.warn(f"forcing EOS token, {long_tail=}, {repetition=}")
+        if long_tail or alignment_repetition or token_repetition:
+            logger.warning(f"forcing EOS token, {long_tail=}, {alignment_repetition=}, {token_repetition=}")
             # (±2**15 is safe for all dtypes >= 16bit)
             logits = -(2**15) * torch.ones_like(logits)
             logits[..., self.eos_idx] = 2**15

-        # Suppress EoS to prevent early termination
-        if cur_text_posn < S - 3:  # FIXME: arbitrary
-            logits[..., self.eos_idx] = -2**15
-
         self.curr_frame_pos += 1
         return logits
chatterbox/src/chatterbox/models/t3/modules/__pycache__/cond_enc.cpython-311.pyc DELETED
Binary file (5.37 kB)
 
chatterbox/src/chatterbox/models/t3/modules/__pycache__/learned_pos_emb.cpython-311.pyc DELETED
Binary file (2.54 kB)
 
chatterbox/src/chatterbox/models/t3/modules/__pycache__/perceiver.cpython-311.pyc DELETED
Binary file (12.6 kB)
 
chatterbox/src/chatterbox/models/t3/modules/__pycache__/t3_config.cpython-311.pyc DELETED
Binary file (1.27 kB)
 
chatterbox/src/chatterbox/models/t3/modules/t3_config.py CHANGED
@@ -2,26 +2,40 @@ from ..llama_configs import LLAMA_CONFIGS


 class T3Config:
-    start_text_token = 255
-    stop_text_token = 0
-    text_tokens_dict_size = 704
-    max_text_tokens = 2048
+    def __init__(self, text_tokens_dict_size=704):
+        self.start_text_token = 255
+        self.stop_text_token = 0
+        self.text_tokens_dict_size = text_tokens_dict_size
+        self.max_text_tokens = 2048

-    start_speech_token = 6561
-    stop_speech_token = 6562
-    speech_tokens_dict_size = 8194
-    max_speech_tokens = 4096
+        self.start_speech_token = 6561
+        self.stop_speech_token = 6562
+        self.speech_tokens_dict_size = 8194
+        self.max_speech_tokens = 4096

-    llama_config_name = "Llama_520M"
-    input_pos_emb = "learned"
-    speech_cond_prompt_len = 150
+        self.llama_config_name = "Llama_520M"
+        self.input_pos_emb = "learned"
+        self.speech_cond_prompt_len = 150

-    # For T3CondEnc
-    encoder_type = "voice_encoder"
-    speaker_embed_size = 256
-    use_perceiver_resampler = True
-    emotion_adv = True
+        self.encoder_type = "voice_encoder"
+        self.speaker_embed_size = 256
+        self.use_perceiver_resampler = True
+        self.emotion_adv = True

     @property
     def n_channels(self):
         return LLAMA_CONFIGS[self.llama_config_name]["hidden_size"]
+
+    @property
+    def is_multilingual(self):
+        return self.text_tokens_dict_size == 2454
+
+    @classmethod
+    def english_only(cls):
+        """Create configuration for English-only TTS model."""
+        return cls(text_tokens_dict_size=704)
+
+    @classmethod
+    def multilingual(cls):
+        """Create configuration for multilingual TTS model."""
+        return cls(text_tokens_dict_size=2454)
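T3Config is now instance-based so the text vocabulary size can vary between checkpoints. A short usage sketch of the new constructors (import path assumed from this repo's layout):

from chatterbox.models.t3.modules.t3_config import T3Config

hp_en = T3Config.english_only()   # text_tokens_dict_size == 704
hp_ml = T3Config.multilingual()   # text_tokens_dict_size == 2454
assert hp_ml.is_multilingual and not hp_en.is_multilingual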
chatterbox/src/chatterbox/models/t3/t3.py CHANGED
@@ -3,12 +3,14 @@
 import logging
 from typing import Union, Optional, List

+logger = logging.getLogger(__name__)
+
 from tqdm import tqdm
 import torch
 import torch.nn.functional as F
 from torch import nn, Tensor
 from transformers import LlamaModel, LlamaConfig
-from transformers.generation.logits_process import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor
+from transformers.generation.logits_process import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor, MinPLogitsWarper

 from .modules.learned_pos_emb import LearnedPositionEmbeddings

@@ -17,17 +19,12 @@ from .modules.t3_config import T3Config
 from .llama_configs import LLAMA_CONFIGS
 from .inference.t3_hf_backend import T3HuggingfaceBackend
 from .inference.alignment_stream_analyzer import AlignmentStreamAnalyzer
+from ..utils import AttrDict


 logger = logging.getLogger(__name__)


-class AttrDict(dict):
-    def __init__(self, *args, **kwargs):
-        super(AttrDict, self).__init__(*args, **kwargs)
-        self.__dict__ = self
-
-
 def _ensure_BOT_EOT(text_tokens: Tensor, hp):
     B = text_tokens.size(0)
     assert (text_tokens == hp.start_text_token).int().sum() >= B, "missing start_text_token"
@@ -44,7 +41,9 @@ class T3(nn.Module):
     different PE embedding space for speech.
     """

-    def __init__(self, hp=T3Config()):
+    def __init__(self, hp=None):
+        if hp is None:
+            hp = T3Config.english_only()  # Default to English-only config for backward compatibility
         super().__init__()
         self.hp = hp
         self.cfg = LlamaConfig(**LLAMA_CONFIGS[hp.llama_config_name])
@@ -89,11 +88,13 @@ class T3(nn.Module):
         t3_cond: T3Cond,
         text_tokens: torch.LongTensor,
         speech_tokens: torch.LongTensor,
+        cfg_weight: float = 0.0,
     ):
         # prepare input embeddings (skip backbone tranformer embeddings)
         cond_emb = self.prepare_conditioning(t3_cond)  # (B, len_cond, dim)
         text_emb = self.text_emb(text_tokens)  # (B, len_text, dim)
-        text_emb[1].zero_()  # CFG uncond
+        if cfg_weight > 0.0:
+            text_emb[1].zero_()  # CFG uncond

         speech_emb = self.speech_emb(speech_tokens)  # (B, len_speech, dim)
         if self.hp.input_pos_emb == "learned":
@@ -221,10 +222,11 @@
         stop_on_eos=True,
         do_sample=True,
         temperature=0.8,
-        top_p=0.8,
+        top_p=0.95,
+        min_p=0.05,
         length_penalty=1.0,
-        repetition_penalty=2.0,
-        cfg_weight=0,
+        repetition_penalty=1.2,
+        cfg_weight=0.5,
     ):
         """
         Args:
@@ -244,6 +246,7 @@
             t3_cond=t3_cond,
             text_tokens=text_tokens,
            speech_tokens=initial_speech_tokens,
+            cfg_weight=cfg_weight,
         )

         # In order to use the standard HF generate method, we need to extend some methods to inject our custom logic
@@ -254,19 +257,24 @@
         # TODO? synchronize the expensive compile function
         # with self.compile_lock:
         if not self.compiled:
-            # alignment_stream_analyzer = AlignmentStreamAnalyzer(
-            #     self.tfmr,
-            #     None,
-            #     text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
-            #     alignment_layer_idx=9,  # TODO: hparam or something?
-            #     eos_idx=self.hp.stop_speech_token,
-            # )
+            # Default to None for English models, only create for multilingual
+            alignment_stream_analyzer = None
+            if self.hp.is_multilingual:
+                alignment_stream_analyzer = AlignmentStreamAnalyzer(
+                    self.tfmr,
+                    None,
+                    text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
+                    alignment_layer_idx=9,  # TODO: hparam or something?
+                    eos_idx=self.hp.stop_speech_token,
+                )
+                assert alignment_stream_analyzer.eos_idx == self.hp.stop_speech_token
+
             patched_model = T3HuggingfaceBackend(
                 config=self.cfg,
                 llama=self.tfmr,
                 speech_enc=self.speech_emb,
                 speech_head=self.speech_head,
-                # alignment_stream_analyzer=alignment_stream_analyzer,
+                alignment_stream_analyzer=alignment_stream_analyzer,
             )
             self.patched_model = patched_model
             self.compiled = True
@@ -281,7 +289,7 @@
        #     max_new_tokens=max_new_tokens or self.hp.max_speech_tokens,
        #     num_return_sequences=num_return_sequences,
        #     temperature=temperature,
-       #     top_p=top_p,
+       #     min_p=min_p,
        #     length_penalty=length_penalty,
        #     repetition_penalty=repetition_penalty,
        #     do_sample=do_sample,
@@ -306,7 +314,9 @@

         # Instantiate the logits processors.
         top_p_warper = TopPLogitsWarper(top_p=top_p)
-        repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
+        min_p_warper = MinPLogitsWarper(min_p=min_p)
+        top_p_warper = TopPLogitsWarper(top_p=top_p)
+        repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty))

         # ---- Initial Forward Pass (no kv_cache yet) ----
         output = self.patched_model(
@@ -322,21 +332,32 @@

         # ---- Generation Loop using kv_cache ----
         for i in tqdm(range(max_new_tokens), desc="Sampling", dynamic_ncols=True):
-            logits = output.logits[:, -1, :]
-
-            # CFG
-            logits_cond = logits[0:1]
-            logits_uncond = logits[1:2]
-            logits = logits_cond + cfg_weight * (logits_cond - logits_uncond)
-            logits = logits.squeeze(1)
-
+            logits_step = output.logits[:, -1, :]
+            # CFG combine → (1, V)
+            cond = logits_step[0:1, :]
+            uncond = logits_step[1:2, :]
+            cfg = torch.as_tensor(cfg_weight, device=cond.device, dtype=cond.dtype)
+            logits = cond + cfg * (cond - uncond)
+
+            # Apply alignment stream analyzer integrity checks
+            if self.patched_model.alignment_stream_analyzer is not None:
+                if logits.dim() == 1:  # guard in case something upstream squeezed
+                    logits = logits.unsqueeze(0)  # (1, V)
+                # Pass the last generated token for repetition tracking
+                last_token = generated_ids[0, -1].item() if len(generated_ids[0]) > 0 else None
+                logits = self.patched_model.alignment_stream_analyzer.step(logits, next_token=last_token)  # (1, V)
+
+            # Apply repetition penalty
+            ids_for_proc = generated_ids[:1, ...]  # batch = 1
+            logits = repetition_penalty_processor(ids_for_proc, logits)  # expects (B,V)
+
             # Apply temperature scaling.
             if temperature != 1.0:
                 logits = logits / temperature
-
-            # Apply repetition penalty and top-p filtering.
-            logits = repetition_penalty_processor(generated_ids, logits)
-            logits = top_p_warper(None, logits)
+
+            # Apply min_p and top_p filtering
+            logits = min_p_warper(ids_for_proc, logits)
+            logits = top_p_warper(ids_for_proc, logits)

             # Convert logits to probabilities and sample the next token.
             probs = torch.softmax(logits, dim=-1)
@@ -347,6 +368,7 @@

             # Check for EOS token.
             if next_token.view(-1) == self.hp.stop_speech_token:
+                logger.info(f"✅ EOS token detected! Stopping generation at step {i+1}")
                 break

             # Get embedding for the new token.
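Taken together, the new sampling loop in t3.py applies, in order: CFG combination of conditional and unconditional logits, the alignment analyzer's EOS gating, repetition penalty, temperature, then min-p and top-p filtering before sampling. A condensed, standalone sketch of that ordering (the helper name and tensor shapes are illustrative, not part of the repo):

import torch
from transformers.generation.logits_process import (
    MinPLogitsWarper, RepetitionPenaltyLogitsProcessor, TopPLogitsWarper,
)

def sample_next_token(logits_cond, logits_uncond, generated_ids,
                      cfg_weight=0.5, repetition_penalty=1.2,
                      temperature=0.8, min_p=0.05, top_p=0.95):
    # CFG: push conditional logits away from unconditional ones.  Shapes: (1, V) and (1, T).
    logits = logits_cond + cfg_weight * (logits_cond - logits_uncond)
    # Penalize tokens already present in generated_ids, then temperature-scale.
    logits = RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)(generated_ids, logits)
    logits = logits / temperature
    # min-p then top-p filtering, as in the updated loop.
    logits = MinPLogitsWarper(min_p=min_p)(generated_ids, logits)
    logits = TopPLogitsWarper(top_p=top_p)(generated_ids, logits)
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # (1, 1)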