mlinmg committed on
Commit
7e1e475
1 Parent(s): 09a868c

Upload 2 files

Browse files
Files changed (2) hide show
  1. config.json +9 -78
  2. gpt_config.py +83 -189
config.json CHANGED
@@ -1,109 +1,40 @@
1
  {
2
- "_name_or_path": "AstraMindAI/xtts2-gpt",
3
  "architectures": [
4
  "XttsGPT"
5
  ],
6
- "torch_dtype": "float32",
7
- "auto_map": {
8
- "AutoConfig": "AstraMindAI/xtts2-gpt--gpt_config.XTTSGPTConfig",
9
- "AutoModelForCausalLM": "AstraMindAI/xtts2-gpt--xtts2_gpt_modeling.XttsGPT",
10
- "AutoTokenizer": "AstraMindAI/xtts2-gpt--tokenizer.XTTSTokenizerFast"
11
- },
12
- "activation_function": "gelu",
13
- "attn_pdrop": 0.1,
14
  "audio_config": {
15
- "fmax": 8000,
16
- "fmin": 0,
17
- "hop_length": 256,
18
  "mel_channels": 80,
19
- "mel_norms_file": null,
20
- "n_fft": 1024,
21
  "output_sample_rate": 24000,
22
- "power": 1.0,
23
- "sample_rate": 22050,
24
- "win_length": 1024
25
  },
26
- "batch_size": 32,
27
- "char_limits": {
28
- "ar": 166,
29
- "cs": 186,
30
- "de": 253,
31
- "en": 250,
32
- "es": 239,
33
- "fr": 273,
34
- "hu": 224,
35
- "it": 213,
36
- "ja": 71,
37
- "ko": 95,
38
- "nl": 251,
39
- "pl": 224,
40
- "pt": 203,
41
- "ru": 182,
42
- "tr": 226,
43
- "zh": 82
44
  },
45
- "checkpointing": false,
46
- "clvp_checkpoint": null,
47
- "code_stride_len": 1024,
48
- "cond_chunk_len": 4,
49
- "cond_d_vector_in_each_upsampling_layer": true,
50
- "cond_len": 30,
51
- "d_vector_dim": 512,
52
- "decoder_checkpoint": null,
53
  "decoder_input_dim": 1024,
54
- "duration_const": 102400,
55
- "embd_pdrop": 0.1,
56
  "enable_redaction": false,
 
 
57
  "hidden_size": 1024,
58
- "input_sample_rate": 22050,
59
  "kv_cache": true,
60
- "label_smoothing": 0.0,
61
- "languages": [
62
- "en",
63
- "es",
64
- "fr",
65
- "de",
66
- "it",
67
- "pt",
68
- "pl",
69
- "tr",
70
- "ru",
71
- "nl",
72
- "cs",
73
- "ar",
74
- "zh-cn",
75
- "hu",
76
- "ko",
77
- "ja",
78
- "hi"
79
- ],
80
  "layer_norm_epsilon": 1e-05,
81
  "max_audio_tokens": 605,
82
- "max_position_embeddings": 2048,
83
  "max_prompt_tokens": 70,
84
- "max_ref_len": 30,
85
  "max_text_tokens": 402,
86
  "model_type": "xtts_gpt",
87
- "n_inner": null,
88
  "num_attention_heads": 16,
89
- "num_chars": 255,
90
  "num_hidden_layers": 30,
91
  "number_text_tokens": 6681,
92
- "output_hop_length": 256,
93
- "output_sample_rate": 24000,
94
- "perceiver_cond_length_compression": 256,
95
  "reorder_and_upcast_attn": false,
96
- "resid_pdrop": 0.1,
97
  "scale_attn_by_inverse_layer_idx": false,
98
- "sound_norm_refs": false,
99
  "start_audio_token": 1024,
100
  "start_text_token": null,
101
  "stop_audio_token": 1025,
102
  "stop_text_token": null,
103
- "tokenizer_file": "",
104
- "train_solo_embeddings": false,
105
  "transformers_version": "4.46.0",
106
  "use_masking_gt_prompt_approach": true,
107
  "use_perceiver_resampler": true,
108
- "vocab_size": 1026
109
  }
 
1
  {
 
2
  "architectures": [
3
  "XttsGPT"
4
  ],
 
 
 
 
 
 
 
 
5
  "audio_config": {
 
 
 
6
  "mel_channels": 80,
 
 
7
  "output_sample_rate": 24000,
8
+ "sample_rate": 22050
 
 
9
  },
10
+ "auto_map": {
11
+ "AutoConfig": "AstraMindAI/xtts2-gpt--gpt_config.XTTSGPTConfig",
12
+ "AutoModelForCausalLM": "AstraMindAI/xtts2-gpt--xtts2_gpt_modeling.XttsGPT"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  },
 
 
 
 
 
 
 
 
14
  "decoder_input_dim": 1024,
 
 
15
  "enable_redaction": false,
16
+ "gpt_batch_size": 1,
17
+ "gpt_max_audio_tokens": 605,
18
  "hidden_size": 1024,
19
+ "initializer_range": 0.02,
20
  "kv_cache": true,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  "layer_norm_epsilon": 1e-05,
22
  "max_audio_tokens": 605,
 
23
  "max_prompt_tokens": 70,
 
24
  "max_text_tokens": 402,
25
  "model_type": "xtts_gpt",
 
26
  "num_attention_heads": 16,
27
+ "num_audio_tokens": 1026,
28
  "num_hidden_layers": 30,
29
  "number_text_tokens": 6681,
 
 
 
30
  "reorder_and_upcast_attn": false,
 
31
  "scale_attn_by_inverse_layer_idx": false,
 
32
  "start_audio_token": 1024,
33
  "start_text_token": null,
34
  "stop_audio_token": 1025,
35
  "stop_text_token": null,
 
 
36
  "transformers_version": "4.46.0",
37
  "use_masking_gt_prompt_approach": true,
38
  "use_perceiver_resampler": true,
39
+ "vocab_size": 6681
40
  }
gpt_config.py CHANGED
@@ -5,6 +5,14 @@ from transformers.utils import logging
5
 
6
  logger = logging.get_logger(__name__)
7
 
 
 
 
 
 
 
 
 
8
  @dataclass
9
  class XTTSAudioConfig:
10
  """Configuration for audio processing parameters"""
@@ -19,226 +27,112 @@ class XTTSAudioConfig:
19
  power: float = 1.0
20
  mel_norms_file: Optional[str] = None
21
 
 
22
  class XTTSGPTConfig(PretrainedConfig):
23
- """Configuration class for the GPT component of XTTS with automatic legacy conversion"""
24
  model_type = "xtts_gpt"
25
 
26
  def __init__(
27
  self,
28
  # Model architecture
29
- vocab_size: int = 1026, # num_audio_tokens
30
- hidden_size: int = 1024, # Changed from gpt_n_model_channels
31
- num_hidden_layers: int = 30, # Changed from gpt_layers
32
- num_attention_heads: int = 16, # Changed from gpt_n_heads
33
- n_inner: Optional[int] = None, # Added for GPT-2 compatibility
34
- max_position_embeddings: int = 2048, # Added for positional embeddings
35
- layer_norm_epsilon: float = 1e-5, # Added for layer norm
36
- activation_function: str = "gelu", # Added activation function
37
- resid_pdrop: float = 0.1, # Added dropout rates
38
- embd_pdrop: float = 0.1,
39
- attn_pdrop: float = 0.1,
40
-
41
- # Specific XTTS parameters
42
- num_chars: int = 255,
43
- batch_size: int = 1, # Changed from gpt_batch_size
44
- max_audio_tokens: int = 605, # Changed from gpt_max_audio_tokens
45
- max_text_tokens: int = 402, # Changed from gpt_max_text_tokens
46
- max_prompt_tokens: int = 70, # Changed from gpt_max_prompt_tokens
47
- number_text_tokens: int = 6681, # Changed from gpt_number_text_tokens
48
- start_text_token: Optional[int] = None, # Changed from gpt_start_text_token
49
- stop_text_token: Optional[int] = None, # Changed from gpt_stop_text_token
50
- start_audio_token: int = 1024, # Changed from gpt_start_audio_token
51
- stop_audio_token: int = 1025, # Changed from gpt_stop_audio_token
52
- code_stride_len: int = 1024, # Changed from gpt_code_stride_len
53
- use_masking_gt_prompt_approach: bool = True, # Changed from gpt_use_masking_gt_prompt_approach
54
- use_perceiver_resampler: bool = True, # Changed from gpt_use_perceiver_resampler
55
- checkpointing: bool = False, # Changed from gpt_checkpointing
56
- train_solo_embeddings: bool = False, # Changed from gpt_train_solo_embeddings
57
-
58
- # Training parameters
59
- enable_redaction: bool = False,
60
  kv_cache: bool = True,
61
- perceiver_cond_length_compression: int = 256,
62
- label_smoothing: float = 0.0,
63
 
64
- # Generation parameters
65
- cond_len: int = 30, # Changed from gpt_cond_len
66
- cond_chunk_len: int = 4, # Changed from gpt_cond_chunk_len
67
- max_ref_len: int = 30,
68
- sound_norm_refs: bool = False,
69
 
70
  # Audio processing
71
- audio_config: Optional[XTTSAudioConfig] = None,
72
-
73
- # Constants and limits
74
- duration_const: int = 102400,
75
- char_limits: Optional[Dict[str, int]] = None,
76
- languages: Optional[List[str]] = None,
77
 
78
-
79
- # GPT-2 compatibility flags
 
 
80
  scale_attn_by_inverse_layer_idx: bool = False,
81
  reorder_and_upcast_attn: bool = False,
82
- add_cross_attention: bool = False,
83
- tie_word_embeddings: bool = True,
 
 
 
 
 
 
84
  **kwargs
85
  ):
86
- # Handle legacy config conversion
87
- if any(k.startswith('gpt_') for k in kwargs):
88
- kwargs = self._convert_legacy_config(kwargs)
89
-
90
- if 'model_args' in kwargs:
91
- kwargs = self._convert_legacy_config(kwargs['model_args'])
92
-
93
- # Initialize audio config
94
- if audio_config is None:
95
- audio_config = XTTSAudioConfig()
96
- elif isinstance(audio_config, dict):
97
- audio_config = XTTSAudioConfig(**audio_config)
98
-
99
- # Set default char limits
100
- if char_limits is None:
101
- char_limits = {
102
- "en": 250, "de": 253, "fr": 273, "es": 239,
103
- "it": 213, "pt": 203, "pl": 224, "zh": 82,
104
- "ar": 166, "cs": 186, "ru": 182, "nl": 251,
105
- "tr": 226, "ja": 71, "hu": 224, "ko": 95,
106
- }
107
-
108
- # Set default languages
109
- if languages is None:
110
- languages = [
111
- "en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl",
112
- "cs", "ar", "zh-cn", "hu", "ko", "ja", "hi"
113
- ]
114
-
115
- super().__init__(
116
- pad_token_id=kwargs.pop('pad_token_id', None),
117
- bos_token_id=kwargs.pop('bos_token_id', None),
118
- eos_token_id=kwargs.pop('eos_token_id', None),
119
- **kwargs
120
  )
121
 
122
- # Set all attributes
123
- self.vocab_size = vocab_size
124
  self.hidden_size = hidden_size
125
  self.num_hidden_layers = num_hidden_layers
126
  self.num_attention_heads = num_attention_heads
127
- self.n_inner = n_inner
128
- self.max_position_embeddings = max_position_embeddings
129
- self.layer_norm_epsilon = layer_norm_epsilon
130
- self.activation_function = activation_function
131
- self.resid_pdrop = resid_pdrop
132
- self.embd_pdrop = embd_pdrop
133
- self.attn_pdrop = attn_pdrop
134
-
135
- # XTTS specific
136
- self.num_chars = num_chars
137
- self.batch_size = batch_size
138
- self.max_audio_tokens = max_audio_tokens
139
- self.max_text_tokens = max_text_tokens
140
- self.max_prompt_tokens = max_prompt_tokens
141
  self.number_text_tokens = number_text_tokens
142
  self.start_text_token = start_text_token
143
  self.stop_text_token = stop_text_token
 
 
144
  self.start_audio_token = start_audio_token
145
  self.stop_audio_token = stop_audio_token
146
- self.code_stride_len = code_stride_len
 
 
 
 
 
147
  self.use_masking_gt_prompt_approach = use_masking_gt_prompt_approach
148
  self.use_perceiver_resampler = use_perceiver_resampler
149
- self.checkpointing = checkpointing
150
- self.train_solo_embeddings = train_solo_embeddings
151
-
152
- # Training
153
- self.enable_redaction = enable_redaction
154
  self.kv_cache = kv_cache
155
- self.perceiver_cond_length_compression = perceiver_cond_length_compression
156
- self.label_smoothing = label_smoothing
157
-
158
- # Generation
159
- self.cond_len = cond_len
160
- self.cond_chunk_len = cond_chunk_len
161
- self.max_ref_len = max_ref_len
162
- self.sound_norm_refs = sound_norm_refs
163
-
164
- # Audio and other
165
- self.audio_config = audio_config
166
- self.duration_const = duration_const
167
- self.char_limits = char_limits
168
- self.languages = languages
169
-
170
- # GPT-2 flags
171
  self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
172
  self.reorder_and_upcast_attn = reorder_and_upcast_attn
173
- self.add_cross_attention = add_cross_attention
174
- self.tie_word_embeddings = tie_word_embeddings
175
-
176
- @staticmethod
177
- def _convert_legacy_config(config_dict: Dict) -> Dict:
178
- """Converts legacy config format to new format."""
179
- mapping = {
180
- 'gpt_batch_size': 'batch_size',
181
- 'gpt_max_audio_tokens': 'max_audio_tokens',
182
- 'gpt_max_text_tokens': 'max_text_tokens',
183
- 'gpt_max_prompt_tokens': 'max_prompt_tokens',
184
- 'gpt_layers': 'num_hidden_layers',
185
- 'gpt_n_model_channels': 'hidden_size',
186
- 'gpt_n_heads': 'num_attention_heads',
187
- 'gpt_number_text_tokens': 'number_text_tokens',
188
- 'gpt_start_text_token': 'start_text_token',
189
- 'gpt_stop_text_token': 'stop_text_token',
190
- 'gpt_num_audio_tokens': 'vocab_size',
191
- 'gpt_start_audio_token': 'start_audio_token',
192
- 'gpt_stop_audio_token': 'stop_audio_token',
193
- 'gpt_code_stride_len': 'code_stride_len',
194
- 'gpt_use_masking_gt_prompt_approach': 'use_masking_gt_prompt_approach',
195
- 'gpt_use_perceiver_resampler': 'use_perceiver_resampler',
196
- 'gpt_checkpointing': 'checkpointing',
197
- 'gpt_train_solo_embeddings': 'train_solo_embeddings',
198
- 'gpt_cond_len': 'cond_len',
199
- 'gpt_cond_chunk_len': 'cond_chunk_len'
200
- }
201
-
202
- new_config = {}
203
-
204
- # Convert keys
205
- for old_key, new_key in mapping.items():
206
- if old_key in config_dict:
207
- new_config[new_key] = config_dict[old_key]
208
-
209
- # Copy non-mapped keys
210
- for k, v in config_dict.items():
211
- if not k.startswith('gpt_') and k not in new_config:
212
- new_config[k] = v
213
-
214
- # Handle audio config
215
- if 'input_sample_rate' in config_dict or 'output_sample_rate' in config_dict:
216
- audio_config = {
217
- 'sample_rate': config_dict.get('input_sample_rate', 22050),
218
- 'output_sample_rate': config_dict.get('output_sample_rate', 24000),
219
- 'hop_length': config_dict.get('output_hop_length', 256)
220
- }
221
- new_config['audio_config'] = audio_config
222
-
223
- return new_config
224
 
225
  def to_dict(self) -> Dict:
226
- """Convert config to dictionary"""
227
- config_dict = super().to_dict()
228
- config_dict["audio_config"] = asdict(self.audio_config)
229
- return config_dict
230
 
231
  @classmethod
232
- def from_dict(cls, config_dict: Dict, **kwargs) -> 'XTTSGPTConfig':
233
- """Create config from dictionary"""
234
- if isinstance(config_dict.get("audio_config"), dict):
235
- audio_config = XTTSAudioConfig(**config_dict["audio_config"])
236
- config_dict["audio_config"] = audio_config
237
- return cls(**config_dict, **kwargs)
238
-
239
- def update_with_tokenizer(self, tokenizer=None):
240
- """Update configuration values based on tokenizer"""
241
- if tokenizer is not None:
242
- self.number_text_tokens = tokenizer.get_vocab_size()
243
- self.start_text_token = tokenizer.bos_token_id
244
- self.stop_text_token = tokenizer.eos_token_id
 
5
 
6
  logger = logging.get_logger(__name__)
7
 
8
+
9
+ @dataclass
10
+ class GPTAudioConfig:
11
+ """Configuration for GPT audio processing parameters"""
12
+ mel_channels: int = 80
13
+ sample_rate: int = 22050
14
+ output_sample_rate: int = 24000
15
+
16
  @dataclass
17
  class XTTSAudioConfig:
18
  """Configuration for audio processing parameters"""
 
27
  power: float = 1.0
28
  mel_norms_file: Optional[str] = None
29
 
30
+
31
  class XTTSGPTConfig(PretrainedConfig):
32
+ """Configuration class for the GPT component of XTTS."""
33
  model_type = "xtts_gpt"
34
 
35
  def __init__(
36
  self,
37
  # Model architecture
38
+ hidden_size: int = 1024, # gpt_n_model_channels in original
39
+ num_hidden_layers: int = 30, # gpt_layers in original
40
+ num_attention_heads: int = 16, # gpt_n_heads in original
41
+
42
+ # Tokenizer settings
43
+ vocab_size: int = 6681, # gpt_number_text_tokens in original
44
+ number_text_tokens: int = 6681, # Explicit text token vocabulary size
45
+ start_text_token: Optional[int] = None,
46
+ stop_text_token: Optional[int] = None,
47
+
48
+ # Audio token settings
49
+ num_audio_tokens: int = 1026, # gpt_num_audio_tokens in original
50
+ start_audio_token: int = 1024, # gpt_start_audio_token in original
51
+ stop_audio_token: int = 1025, # gpt_stop_audio_token in original
52
+
53
+ # Sequence length settings
54
+ max_audio_tokens: int = 605, # gpt_max_audio_tokens in original
55
+ max_text_tokens: int = 402, # gpt_max_text_tokens in original
56
+ max_prompt_tokens: int = 70, # gpt_max_prompt_tokens in original
57
+ gpt_max_audio_tokens: int = 605, # Used for generation
58
+
59
+ # Model behavior settings
60
+ use_masking_gt_prompt_approach: bool = True, # gpt_use_masking_gt_prompt_approach in original
61
+ use_perceiver_resampler: bool = True, # gpt_use_perceiver_resampler in original
 
 
 
 
 
 
 
62
  kv_cache: bool = True,
63
+ enable_redaction: bool = False,
 
64
 
65
+ # GPT batch settings
66
+ gpt_batch_size: int = 1,
 
 
 
67
 
68
  # Audio processing
69
+ audio_config: Optional[Dict] = None,
 
 
 
 
 
70
 
71
+ # Architecture specifics
72
+ layer_norm_epsilon: float = 1e-5,
73
+ initializer_range: float = 0.02,
74
+ add_cross_attention: bool = False,
75
  scale_attn_by_inverse_layer_idx: bool = False,
76
  reorder_and_upcast_attn: bool = False,
77
+
78
+ # Size settings for the decoder
79
+ decoder_input_dim: int = 1024,
80
+ architectures=["XttsGPT"],
81
+ auto_map={
82
+ "AutoConfig": "AstraMindAI/xtts2-gpt--gpt_config.XTTSGPTConfig",
83
+ "AutoModelForCausalLM": "AstraMindAI/xtts2-gpt--xtts2_gpt_modeling.XttsGPT",
84
+ },
85
  **kwargs
86
  ):
87
+ super().__init__(**kwargs)
88
+ self.architectures = architectures
89
+ self.auto_map = auto_map
90
+ self.audio_config = GPTAudioConfig(
91
+ **audio_config if audio_config is not None else {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  )
93
 
 
 
94
  self.hidden_size = hidden_size
95
  self.num_hidden_layers = num_hidden_layers
96
  self.num_attention_heads = num_attention_heads
97
+
98
+ self.vocab_size = vocab_size
 
 
 
 
 
 
 
 
 
 
 
 
99
  self.number_text_tokens = number_text_tokens
100
  self.start_text_token = start_text_token
101
  self.stop_text_token = stop_text_token
102
+
103
+ self.num_audio_tokens = num_audio_tokens
104
  self.start_audio_token = start_audio_token
105
  self.stop_audio_token = stop_audio_token
106
+
107
+ self.max_audio_tokens = max_audio_tokens
108
+ self.max_text_tokens = max_text_tokens
109
+ self.max_prompt_tokens = max_prompt_tokens
110
+ self.gpt_max_audio_tokens = gpt_max_audio_tokens
111
+
112
  self.use_masking_gt_prompt_approach = use_masking_gt_prompt_approach
113
  self.use_perceiver_resampler = use_perceiver_resampler
 
 
 
 
 
114
  self.kv_cache = kv_cache
115
+ self.enable_redaction = enable_redaction
116
+
117
+ self.gpt_batch_size = gpt_batch_size
118
+
119
+ self.layer_norm_epsilon = layer_norm_epsilon
120
+ self.initializer_range = initializer_range
121
+ self.add_cross_attention = add_cross_attention
 
 
 
 
 
 
 
 
 
122
  self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
123
  self.reorder_and_upcast_attn = reorder_and_upcast_attn
124
+
125
+ self.decoder_input_dim = decoder_input_dim
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
  def to_dict(self) -> Dict:
128
+ """Convert the config to a dictionary."""
129
+ output = super().to_dict()
130
+ output["audio_config"] = asdict(self.audio_config)
131
+ return output
132
 
133
  @classmethod
134
+ def from_dict(cls, config_dict: Dict, *args, **kwargs) -> "XTTSGPTConfig":
135
+ """Create a config from a dictionary."""
136
+ return cls(**config_dict)
137
+
138
+