# coding=utf-8
# Copyright 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
# ... (License headers remain the same) ...
""" SparkTTS model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from typing import List, Optional
logger = logging.get_logger(__name__)
# --- Define Individual Sub-Component Config Classes ---
class SparkTTSMelParamsConfig(PretrainedConfig):
"""Configuration for Mel Spectrogram parameters."""
model_type = "spark-tts-mel-params"
def __init__(self, sample_rate=16000, n_fft=1024, win_length=640, hop_length=320,
mel_fmin=10, mel_fmax=None, num_mels=128, **kwargs):
super().__init__(**kwargs)
self.sample_rate = sample_rate
self.n_fft = n_fft
self.win_length = win_length
self.hop_length = hop_length
self.mel_fmin = mel_fmin
self.mel_fmax = mel_fmax
self.num_mels = num_mels
class SparkTTSEncoderConfig(PretrainedConfig):
"""Configuration for the BiCodec Feature Encoder."""
model_type = "spark-tts-encoder"
def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
vocos_num_layers=12, out_channels=1024, sample_ratios=[1, 1], **kwargs):
super().__init__(**kwargs)
self.input_channels = input_channels
self.vocos_dim = vocos_dim
self.vocos_intermediate_dim = vocos_intermediate_dim
self.vocos_num_layers = vocos_num_layers
self.out_channels = out_channels
self.sample_ratios = sample_ratios
class SparkTTSDecoderConfig(PretrainedConfig):
"""Configuration for the BiCodec Wave Generator (Decoder)."""
model_type = "spark-tts-decoder"
def __init__(self, input_channel=1024, channels=1536, rates=[8, 5, 4, 2],
kernel_sizes=[16, 11, 8, 4], **kwargs):
super().__init__(**kwargs)
self.input_channel = input_channel
self.channels = channels
self.rates = rates
self.kernel_sizes = kernel_sizes
class SparkTTSQuantizerConfig(PretrainedConfig):
"""Configuration for the BiCodec Factorized Vector Quantizer."""
model_type = "spark-tts-quantizer"
def __init__(self, input_dim=1024, codebook_size=8192, codebook_dim=8,
commitment=0.25, codebook_loss_weight=2.0, decay=0.99,
threshold_ema_dead_code=0.2, **kwargs):
# Note: `use_l2_normlize` is intentionally omitted because it is not among the original `__init__` arguments.
# Add it back if `FactorizedVectorQuantize.__init__` actually consumes it.
super().__init__(**kwargs)
self.input_dim = input_dim
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.commitment = commitment
self.codebook_loss_weight = codebook_loss_weight
self.decay = decay
self.threshold_ema_dead_code = threshold_ema_dead_code
class SparkTTSSpeakerEncoderConfig(PretrainedConfig):
"""Configuration for the BiCodec Speaker Encoder."""
model_type = "spark-tts-speaker-encoder"
def __init__(self, input_dim=128, out_dim=1024, latent_dim=128, token_num=32,
fsq_levels=[4, 4, 4, 4, 4, 4], fsq_num_quantizers=1, **kwargs):
super().__init__(**kwargs)
self.input_dim = input_dim
self.out_dim = out_dim
self.latent_dim = latent_dim
self.token_num = token_num
self.fsq_levels = fsq_levels
self.fsq_num_quantizers = fsq_num_quantizers
class SparkTTSPrenetConfig(PretrainedConfig):
"""Configuration for the BiCodec Prenet."""
model_type = "spark-tts-prenet"
def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
vocos_num_layers=12, out_channels=1024, condition_dim=1024,
sample_ratios=[1, 1], use_tanh_at_final=False, **kwargs):
super().__init__(**kwargs)
self.input_channels = input_channels
self.vocos_dim = vocos_dim
self.vocos_intermediate_dim = vocos_intermediate_dim
self.vocos_num_layers = vocos_num_layers
self.out_channels = out_channels
self.condition_dim = condition_dim
self.sample_ratios = sample_ratios
self.use_tanh_at_final = use_tanh_at_final
class SparkTTSPostnetConfig(PretrainedConfig):
"""Configuration for the BiCodec Postnet."""
model_type = "spark-tts-postnet"
def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
vocos_num_layers=6, out_channels=1024, use_tanh_at_final=False, **kwargs):
# Note: `condition_dim` is intentionally omitted because the reference postnet configuration does not include it.
super().__init__(**kwargs)
self.input_channels = input_channels
self.vocos_dim = vocos_dim
self.vocos_intermediate_dim = vocos_intermediate_dim
self.vocos_num_layers = vocos_num_layers
self.out_channels = out_channels
self.use_tanh_at_final = use_tanh_at_final
# --- Define the Intermediate BiCodec Config Class ---
class SparkTTSBiCodecConfig(PretrainedConfig):
"""
Intermediate configuration class for the BiCodec component within SparkTTS.
It holds instances of the individual sub-component configurations.
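
Example (a minimal sketch, assuming this module is importable as `configuration_spark_tts`):

```python
>>> from configuration_spark_tts import SparkTTSBiCodecConfig

>>> # Sub-configs may be passed as plain dicts; anything omitted falls back to its defaults.
>>> bicodec_config = SparkTTSBiCodecConfig(quantizer_config={"codebook_size": 4096})
>>> bicodec_config.quantizer_config.codebook_size
4096
>>> bicodec_config.mel_params.sample_rate
16000
```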
"""
model_type = "spark-tts-bicodec"
# Map keys in the 'bicodec_config' dict to their respective classes
sub_configs = {
"mel_params": SparkTTSMelParamsConfig,
"encoder_config": SparkTTSEncoderConfig,
"decoder_config": SparkTTSDecoderConfig,
"quantizer_config": SparkTTSQuantizerConfig,
"speaker_encoder_config": SparkTTSSpeakerEncoderConfig,
"prenet_config": SparkTTSPrenetConfig,
"postnet_config": SparkTTSPostnetConfig,
}
def __init__(
self,
mel_params=None,
encoder_config=None,
decoder_config=None,
quantizer_config=None,
speaker_encoder_config=None,
prenet_config=None,
postnet_config=None,
**kwargs,
):
super().__init__(**kwargs)
# Instantiate sub-configs from dictionaries or use defaults/provided instances
self.mel_params = self._init_sub_config(mel_params, "mel_params")
self.encoder_config = self._init_sub_config(encoder_config, "encoder_config")
self.decoder_config = self._init_sub_config(decoder_config, "decoder_config")
self.quantizer_config = self._init_sub_config(quantizer_config, "quantizer_config")
self.speaker_encoder_config = self._init_sub_config(speaker_encoder_config, "speaker_encoder_config")
self.prenet_config = self._init_sub_config(prenet_config, "prenet_config")
self.postnet_config = self._init_sub_config(postnet_config, "postnet_config")
def _init_sub_config(self, config_input, config_key):
"""Helper to initialize sub-configs."""
config_cls = self.sub_configs[config_key]
if isinstance(config_input, dict):
return config_cls(**config_input)
elif config_input is None:
return config_cls() # Initialize with defaults
elif isinstance(config_input, config_cls):
return config_input # Already an instance
else:
raise TypeError(f"Invalid type for {config_key}: {type(config_input)}. Expected dict, None, or {config_cls.__name__}.")
# --- Define the Main SparkTTS Config Class ---
class SparkTTSConfig(PretrainedConfig):
r"""
Main configuration class for SparkTTSModel, including nested BiCodec configuration.
Args:
llm_model_name_or_path (`str`, *optional*, defaults to `"./LLM"`): Path/ID for LLM.
bicodec_model_name_or_path (`str`, *optional*, defaults to `"./BiCodec"`): Path/ID for BiCodec checkpoint.
wav2vec2_model_name_or_path (`str`, *optional*, defaults to `"./wav2vec2-large-xlsr-53"`): Path/ID for Wav2Vec2.
sample_rate (`int`, *optional*, defaults to 16000): Audio sample rate.
highpass_cutoff_freq (`int`, *optional*, defaults to 40): Cutoff frequency of the high-pass filter applied to input audio.
latent_hop_length (`int`, *optional*, defaults to 320): Hop length between latent frames, in audio samples.
ref_segment_duration (`float`, *optional*, defaults to 6.0): Duration of the reference audio segment, in seconds.
volume_normalize (`bool`, *optional*, defaults to `True`): Whether to apply volume normalization to input audio.
bicodec_config (`dict`, *optional*): Dictionary to initialize `SparkTTSBiCodecConfig`.
torch_dtype (`str`, *optional*, defaults to `"auto"`): Torch dtype.
kwargs (*optional*): Dictionary of keyword arguments.
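
Example (a minimal sketch; the local import assumes this module is available as `configuration_spark_tts`,
and the checkpoint path in the commented `AutoConfig` lines is hypothetical):

```python
>>> from configuration_spark_tts import SparkTTSConfig

>>> # Defaults: the nested BiCodec configuration is created automatically.
>>> config = SparkTTSConfig()
>>> config.sample_rate
16000
>>> config.bicodec_config.quantizer_config.codebook_size
8192

>>> # Nested values can be overridden with plain dictionaries.
>>> config = SparkTTSConfig(bicodec_config={"quantizer_config": {"codebook_size": 4096}})
>>> config.bicodec_config.quantizer_config.codebook_size
4096

>>> # With `trust_remote_code=True`, the `auto_map` set below lets `AutoConfig` resolve this class:
>>> # from transformers import AutoConfig
>>> # config = AutoConfig.from_pretrained("path/to/spark-tts-checkpoint", trust_remote_code=True)
```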
"""
model_type = "spark-tts"
# Map the key in config.json to the intermediate BiCodec config class
sub_configs = {"bicodec_config": SparkTTSBiCodecConfig}
attribute_map = {"hidden_size": "d_model"}  # Redirects `hidden_size` get/set to `d_model`; `d_model` must be set (e.g. via kwargs) for reads to resolve.
def __init__(
self,
llm_model_name_or_path="./LLM",
bicodec_model_name_or_path="./BiCodec",
wav2vec2_model_name_or_path="./wav2vec2-large-xlsr-53",
sample_rate=16000,
highpass_cutoff_freq=40,
latent_hop_length=320,
ref_segment_duration=6.0,
volume_normalize=True,
bicodec_config=None, # Expects a dictionary or None
torch_dtype="auto",
**kwargs,
):
# --- Top-level parameters ---
self.llm_model_name_or_path = llm_model_name_or_path
self.bicodec_model_name_or_path = bicodec_model_name_or_path
self.wav2vec2_model_name_or_path = wav2vec2_model_name_or_path
self.sample_rate = sample_rate
self.highpass_cutoff_freq = highpass_cutoff_freq
self.latent_hop_length = latent_hop_length
self.ref_segment_duration = ref_segment_duration
self.volume_normalize = volume_normalize
self.torch_dtype = torch_dtype
# --- Nested BiCodec Configuration ---
# Instantiate the intermediate BiCodec config class, which will handle its own sub-configs
if isinstance(bicodec_config, dict):
self.bicodec_config = self.sub_configs["bicodec_config"](**bicodec_config)
elif bicodec_config is None:
logger.info("`bicodec_config` not provided. Initializing `SparkTTSBiCodecConfig` with its defaults.")
self.bicodec_config = self.sub_configs["bicodec_config"]()
elif isinstance(bicodec_config, self.sub_configs["bicodec_config"]):
self.bicodec_config = bicodec_config # Use existing instance
else:
raise TypeError(f"Invalid type for bicodec_config: {type(bicodec_config)}. Expected dict, None, or SparkTTSBiCodecConfig.")
# Set processor class and auto_map
kwargs["processor_class"] = kwargs.get("processor_class", "SparkTTSProcessor")
kwargs["auto_map"] = kwargs.get("auto_map", {
"AutoConfig": "configuration_spark_tts.SparkTTSConfig",
"AutoModel": "modeling_spark_tts.SparkTTSModel",
"AutoProcessor": "processing_spark_tts.SparkTTSProcessor"
})
super().__init__(**kwargs)