# coding=utf-8
# Copyright 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
# ... (License headers remain the same) ...
""" SparkTTS model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)

# --- Define Individual Sub-Component Config Classes ---

class SparkTTSMelParamsConfig(PretrainedConfig):
    """Configuration for Mel Spectrogram parameters."""
    model_type = "spark-tts-mel-params"
    def __init__(self, sample_rate=16000, n_fft=1024, win_length=640, hop_length=320,
                 mel_fmin=10, mel_fmax=None, num_mels=128, **kwargs):
        super().__init__(**kwargs)
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.num_mels = num_mels
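
# Illustrative note (not from the original file): with the defaults above, the mel frame
# rate is sample_rate / hop_length = 16000 / 320 = 50 frames per second, and
# win_length = 640 samples corresponds to a 40 ms analysis window at 16 kHz.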

class SparkTTSEncoderConfig(PretrainedConfig):
    """Configuration for the BiCodec Feature Encoder."""
    model_type = "spark-tts-encoder"
    def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
                 vocos_num_layers=12, out_channels=1024, sample_ratios=[1, 1], **kwargs):
        super().__init__(**kwargs)
        self.input_channels = input_channels
        self.vocos_dim = vocos_dim
        self.vocos_intermediate_dim = vocos_intermediate_dim
        self.vocos_num_layers = vocos_num_layers
        self.out_channels = out_channels
        self.sample_ratios = sample_ratios

class SparkTTSDecoderConfig(PretrainedConfig):
    """Configuration for the BiCodec Wave Generator (Decoder)."""
    model_type = "spark-tts-decoder"
    def __init__(self, input_channel=1024, channels=1536, rates=[8, 5, 4, 2],
                 kernel_sizes=[16, 11, 8, 4], **kwargs):
        super().__init__(**kwargs)
        self.input_channel = input_channel
        self.channels = channels
        self.rates = rates
        self.kernel_sizes = kernel_sizes
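
# Illustrative note (not from the original file): the product of the default upsampling
# `rates`, 8 * 5 * 4 * 2 = 320, matches the default mel `hop_length` of 320 samples, so
# the wave generator is expected to map one latent frame back to 320 audio samples.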

class SparkTTSQuantizerConfig(PretrainedConfig):
    """Configuration for the BiCodec Factorized Vector Quantizer."""
    model_type = "spark-tts-quantizer"
    def __init__(self, input_dim=1024, codebook_size=8192, codebook_dim=8,
                 commitment=0.25, codebook_loss_weight=2.0, decay=0.99,
                 threshold_ema_dead_code=0.2, **kwargs):
        # Note: Removed use_l2_normlize as it wasn't in the original class __init__ args
        # Add it back if it's actually used by the FactorizedVectorQuantize class init
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment
        self.codebook_loss_weight = codebook_loss_weight
        self.decay = decay
        self.threshold_ema_dead_code = threshold_ema_dead_code

class SparkTTSSpeakerEncoderConfig(PretrainedConfig):
    """Configuration for the BiCodec Speaker Encoder."""
    model_type = "spark-tts-speaker-encoder"
    def __init__(self, input_dim=128, out_dim=1024, latent_dim=128, token_num=32,
                 fsq_levels=[4, 4, 4, 4, 4, 4], fsq_num_quantizers=1, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.latent_dim = latent_dim
        self.token_num = token_num
        self.fsq_levels = fsq_levels
        self.fsq_num_quantizers = fsq_num_quantizers
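
# Illustrative note (not from the original file): assuming standard FSQ semantics, the
# per-quantizer codebook size is the product of `fsq_levels`, i.e. 4 ** 6 = 4096 codes,
# alongside `token_num=32` speaker tokens.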

class SparkTTSPrenetConfig(PretrainedConfig):
    """Configuration for the BiCodec Prenet."""
    model_type = "spark-tts-prenet"
    def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
                 vocos_num_layers=12, out_channels=1024, condition_dim=1024,
                 sample_ratios=[1, 1], use_tanh_at_final=False, **kwargs):
        super().__init__(**kwargs)
        self.input_channels = input_channels
        self.vocos_dim = vocos_dim
        self.vocos_intermediate_dim = vocos_intermediate_dim
        self.vocos_num_layers = vocos_num_layers
        self.out_channels = out_channels
        self.condition_dim = condition_dim
        self.sample_ratios = sample_ratios
        self.use_tanh_at_final = use_tanh_at_final

class SparkTTSPostnetConfig(PretrainedConfig):
    """Configuration for the BiCodec Postnet."""
    model_type = "spark-tts-postnet"
    def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
                 vocos_num_layers=6, out_channels=1024, use_tanh_at_final=False, **kwargs):
        # Note: Removed condition_dim as it wasn't in the original config example for postnet
        super().__init__(**kwargs)
        self.input_channels = input_channels
        self.vocos_dim = vocos_dim
        self.vocos_intermediate_dim = vocos_intermediate_dim
        self.vocos_num_layers = vocos_num_layers
        self.out_channels = out_channels
        self.use_tanh_at_final = use_tanh_at_final


# --- Define the Intermediate BiCodec Config Class ---

class SparkTTSBiCodecConfig(PretrainedConfig):
    """
    Intermediate configuration class for the BiCodec component within SparkTTS.
    It holds instances of the individual sub-component configurations.
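
    Example (illustrative sketch showing the input forms accepted and resolved by
    `_init_sub_config`; assumes this module is importable as `configuration_spark_tts`):

    ```python
    from configuration_spark_tts import SparkTTSBiCodecConfig, SparkTTSEncoderConfig

    bicodec = SparkTTSBiCodecConfig(
        quantizer_config={"codebook_size": 4096},             # dict -> SparkTTSQuantizerConfig
        mel_params=None,                                      # None -> defaults
        encoder_config=SparkTTSEncoderConfig(vocos_dim=512),  # existing instance, used as-is
    )
    assert bicodec.quantizer_config.codebook_size == 4096
    assert bicodec.mel_params.sample_rate == 16000
    ```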
    """
    model_type = "spark-tts-bicodec"
    # Map keys in the 'bicodec_config' dict to their respective classes
    sub_configs = {
        "mel_params": SparkTTSMelParamsConfig,
        "encoder_config": SparkTTSEncoderConfig,
        "decoder_config": SparkTTSDecoderConfig,
        "quantizer_config": SparkTTSQuantizerConfig,
        "speaker_encoder_config": SparkTTSSpeakerEncoderConfig,
        "prenet_config": SparkTTSPrenetConfig,
        "postnet_config": SparkTTSPostnetConfig,
    }

    def __init__(
        self,
        mel_params=None,
        encoder_config=None,
        decoder_config=None,
        quantizer_config=None,
        speaker_encoder_config=None,
        prenet_config=None,
        postnet_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Instantiate sub-configs from dictionaries or use defaults/provided instances
        self.mel_params = self._init_sub_config(mel_params, "mel_params")
        self.encoder_config = self._init_sub_config(encoder_config, "encoder_config")
        self.decoder_config = self._init_sub_config(decoder_config, "decoder_config")
        self.quantizer_config = self._init_sub_config(quantizer_config, "quantizer_config")
        self.speaker_encoder_config = self._init_sub_config(speaker_encoder_config, "speaker_encoder_config")
        self.prenet_config = self._init_sub_config(prenet_config, "prenet_config")
        self.postnet_config = self._init_sub_config(postnet_config, "postnet_config")

    def _init_sub_config(self, config_input, config_key):
        """Helper to initialize sub-configs."""
        config_cls = self.sub_configs[config_key]
        if isinstance(config_input, dict):
            return config_cls(**config_input)
        elif config_input is None:
            return config_cls() # Initialize with defaults
        elif isinstance(config_input, config_cls):
            return config_input # Already an instance
        else:
            raise TypeError(f"Invalid type for {config_key}: {type(config_input)}. Expected dict, None, or {config_cls.__name__}.")


# --- Define the Main SparkTTS Config Class ---

class SparkTTSConfig(PretrainedConfig):
    r"""
    Main configuration class for SparkTTSModel, including nested BiCodec configuration.
    Args:
        llm_model_name_or_path (`str`, *optional*, defaults to `"./LLM"`): Path/ID for LLM.
        bicodec_model_name_or_path (`str`, *optional*, defaults to `"./BiCodec"`): Path/ID for BiCodec checkpoint.
        wav2vec2_model_name_or_path (`str`, *optional*, defaults to `"./wav2vec2-large-xlsr-53"`): Path/ID for Wav2Vec2.
        sample_rate (`int`, *optional*, defaults to 16000): Audio sample rate.
        highpass_cutoff_freq (`int`, *optional*, defaults to 40): Cutoff frequency, in Hz, of the high-pass filter applied to input audio.
        latent_hop_length (`int`, *optional*, defaults to 320): Hop length, in audio samples, corresponding to one latent frame.
        ref_segment_duration (`float`, *optional*, defaults to 6.0): Duration, in seconds, of the reference audio segment.
        volume_normalize (`bool`, *optional*, defaults to `True`): Whether to apply volume normalization to input audio.
        bicodec_config (`dict`, *optional*): Dictionary to initialize `SparkTTSBiCodecConfig`.
        torch_dtype (`str`, *optional*, defaults to `"auto"`): Torch dtype.
        kwargs (*optional*): Dictionary of keyword arguments.
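
    Example (illustrative sketch; assumes this module is importable as
    `configuration_spark_tts`, e.g. when a checkpoint is loaded with `trust_remote_code=True`):

    ```python
    from configuration_spark_tts import SparkTTSConfig

    # Defaults everywhere; the nested BiCodec config is built automatically.
    config = SparkTTSConfig()
    assert config.bicodec_config.quantizer_config.codebook_size == 8192

    # Nested values can be overridden with plain dictionaries.
    config = SparkTTSConfig(
        sample_rate=16000,
        bicodec_config={"quantizer_config": {"codebook_size": 4096}},
    )
    assert config.bicodec_config.quantizer_config.codebook_size == 4096
    ```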
    """
    model_type = "spark-tts"
    # Map the key in config.json to the intermediate BiCodec config class
    sub_configs = {"bicodec_config": SparkTTSBiCodecConfig}
    attribute_map = {"hidden_size": "d_model"}  # Placeholder mapping; only takes effect if a `d_model` attribute is defined on the config.

    def __init__(
        self,
        llm_model_name_or_path="./LLM",
        bicodec_model_name_or_path="./BiCodec",
        wav2vec2_model_name_or_path="./wav2vec2-large-xlsr-53",
        sample_rate=16000,
        highpass_cutoff_freq=40,
        latent_hop_length=320,
        ref_segment_duration=6.0,
        volume_normalize=True,
        bicodec_config=None, # Expects a dictionary or None
        torch_dtype="auto",
        **kwargs,
    ):
        # --- Top-level parameters ---
        self.llm_model_name_or_path = llm_model_name_or_path
        self.bicodec_model_name_or_path = bicodec_model_name_or_path
        self.wav2vec2_model_name_or_path = wav2vec2_model_name_or_path
        self.sample_rate = sample_rate
        self.highpass_cutoff_freq = highpass_cutoff_freq
        self.latent_hop_length = latent_hop_length
        self.ref_segment_duration = ref_segment_duration
        self.volume_normalize = volume_normalize
        self.torch_dtype = torch_dtype

        # --- Nested BiCodec Configuration ---
        # Instantiate the intermediate BiCodec config class, which will handle its own sub-configs
        if isinstance(bicodec_config, dict):
            self.bicodec_config = self.sub_configs["bicodec_config"](**bicodec_config)
        elif bicodec_config is None:
            logger.info("`bicodec_config` not provided. Initializing `SparkTTSBiCodecConfig` with its defaults.")
            self.bicodec_config = self.sub_configs["bicodec_config"]()
        elif isinstance(bicodec_config, self.sub_configs["bicodec_config"]):
            self.bicodec_config = bicodec_config  # Use existing instance
        else:
            raise TypeError(
                f"Invalid type for bicodec_config: {type(bicodec_config)}. Expected dict, None, or SparkTTSBiCodecConfig."
            )


        # Set processor class and auto_map
        kwargs["processor_class"] = kwargs.get("processor_class", "SparkTTSProcessor")
        kwargs["auto_map"] = kwargs.get("auto_map", {
              "AutoConfig": "configuration_spark_tts.SparkTTSConfig",
              "AutoModel": "modeling_spark_tts.SparkTTSModel",
              "AutoProcessor": "processing_spark_tts.SparkTTSProcessor"
            })
        super().__init__(**kwargs)
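

# --- Illustrative usage sketch (not part of the original file) ---
# The `auto_map` entries set above let `AutoConfig` / `AutoModel` / `AutoProcessor`
# resolve these custom classes when a checkpoint repository that ships this file is
# loaded with `trust_remote_code=True`. A minimal sketch, assuming such a repository
# (the path below is a placeholder):
#
#     from transformers import AutoConfig
#
#     config = AutoConfig.from_pretrained("path/or/repo-id", trust_remote_code=True)
#     print(config.bicodec_config.quantizer_config.codebook_size)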