# xtts2-gpt / xtts2_gpt_modeling.py
import functools
import math
from array import array
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import List, Optional, Union, Iterable, Tuple, Mapping
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group
from vllm.inputs import InputContext, INPUT_REGISTRY
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.gpt2 import GPT2Block
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
from vllm.sequence import IntermediateTensors, SequenceData, VLLM_TOKEN_ID_ARRAY_TYPE
from vllm.model_executor.models.interfaces import SupportsMultiModal
from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder # noqa
from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler # noqa
from TTS.tts.layers.xtts.gpt import LearnedPositionEmbeddings, null_position_embeddings  # noqa
# Constants for token calculation
_AUDIO_PLACEHOLDER_TOKEN = 8192 # Using XTTS start_audio_token as placeholder
_AUDIO_TOKENS_PER_SECOND = 6.25
_CODE_STRIDE_LEN = 1024
def get_xtts_max_audio_tokens(ctx: InputContext) -> int:
"""Calculate maximum audio tokens based on text context and audio duration."""
    # Based on the GPT config and common XTTS settings
    text_context = ctx.model_config.max_model_len - 100  # reserve space for text tokens
    # Allow for ~30 seconds of audio (similar to Whisper chunks)
    max_audio_duration = 30.0
    audio_tokens = math.ceil(max_audio_duration * _AUDIO_TOKENS_PER_SECOND)
    total_tokens = text_context + audio_tokens + 4  # +4 for special tokens
    return min(total_tokens, 1000)  # hard cap on multimodal tokens
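# Worked example for get_xtts_max_audio_tokens (assumed values): with
# max_model_len = 2048 the text budget is 1948 tokens, 30 s of audio adds
# ceil(30 * 6.25) = 188 audio tokens plus 4 specials, and the total is then
# clamped to the 1000-token cap.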
def dummy_seq_data_for_xtts(
ctx: InputContext,
seq_len: int,
audio_count: int,
) -> SequenceData:
"""Create dummy sequence data for XTTS profiling."""
    # Calculate the audio-token span for each dummy chunk (~5 s of audio)
    audio_len_tokens = math.ceil(_AUDIO_TOKENS_PER_SECOND * 5)
    audio_placeholder = array(
        VLLM_TOKEN_ID_ARRAY_TYPE,
        [_AUDIO_PLACEHOLDER_TOKEN]
    ) * audio_len_tokens
    # Add a separator token between chunks
    audio_token_ids = (audio_placeholder + array(VLLM_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count
    # Fill the remaining sequence with padding (never a negative count)
    other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * max(0, seq_len - len(audio_token_ids))
    return SequenceData(audio_token_ids + other_token_ids)
def dummy_conditioning_for_xtts(
ctx: InputContext,
audio_count: int,
) -> dict:
"""Create dummy conditioning data for XTTS."""
return {
"cond_latents": [(torch.zeros(80, 1024), 22050) for _ in range(audio_count)]
}
def dummy_data_for_xtts(
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Tuple[SequenceData, dict]:
"""Create complete dummy data for XTTS profiling."""
audio_count = mm_counts["audio"]
seq_data = dummy_seq_data_for_xtts(ctx, seq_len, audio_count)
cond_data = dummy_conditioning_for_xtts(ctx, audio_count)
return (seq_data, cond_data)
def input_mapper_for_xtts(ctx: InputContext, data: object) -> MultiModalInputs:
"""Map input data to XTTS format."""
if not isinstance(data, list):
data = [data]
# Each item should be a tuple of (mel_spec, sample_rate)
for audio_input in data:
if not isinstance(audio_input, tuple):
raise NotImplementedError(f"Unsupported data type: {type(audio_input)}")
return MultiModalInputs({"cond_latents": data})
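# The decorators below register the audio input mapper, the per-request audio-token
# budget, and the dummy profiling-data factory with vLLM's input registries.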
@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_xtts)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens("audio", get_xtts_max_audio_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_xtts)
class XttsGPT(nn.Module, SupportsMultiModal):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional["QuantizationConfig"] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
# XTTS specific components
self.conditioning_encoder = ConditioningEncoder(
80, config.n_embd, num_attn_heads=config.n_head
)
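        # Optionally compress the variable-length conditioning sequence into a
        # fixed set of 32 latent vectors with a perceiver resampler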
if config.use_perceiver_resampler:
self.conditioning_perceiver = PerceiverResampler(
dim=config.n_embd,
depth=2,
dim_context=config.n_embd,
num_latents=32,
dim_head=64,
heads=8,
ff_mult=4,
use_flash_attn=False,
)
# Core GPT components following VLLM pattern
self.gpt = XttsGPT2Model(
config,
cache_config,
quant_config,
prefix="gpt"
)
# Prediction heads
self.text_head = ColumnParallelLinear(
config.n_embd,
config.vocab_size,
bias=False,
quant_config=quant_config,
prefix="text_head"
)
self.mel_head = ColumnParallelLinear(
config.n_embd,
config.num_audio_tokens,
bias=False,
quant_config=quant_config,
prefix="mel_head"
)
self.sampler = Sampler()
def get_style_emb(self, cond_input: torch.Tensor, return_latent: bool = False) -> torch.Tensor:
"""Get conditioning embeddings from mel spectrograms."""
if not return_latent:
if cond_input.ndim == 4:
cond_input = cond_input.squeeze(1)
conds = self.conditioning_encoder(cond_input)
if hasattr(self, 'conditioning_perceiver'):
conds = self.conditioning_perceiver(
conds.permute(0, 2, 1)
).transpose(1, 2)
else:
conds = cond_input.unsqueeze(1)
return conds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
cond_latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass following VLLM pattern."""
if cond_latents is not None:
# Combine conditioning with input embeddings
input_embeds = self.gpt.get_input_embeddings()(input_ids)
combined_embeds = torch.cat([cond_latents, input_embeds], dim=1)
hidden_states = self.gpt(
inputs_embeds=combined_embeds,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
)
else:
hidden_states = self.gpt(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
)
return hidden_states
    def compute_logits(  # kept for compatibility with the vLLM model interface
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Compute output logits from the text and mel heads."""
        selected = hidden_states[sampling_metadata.selected_token_indices]
        # ColumnParallelLinear returns a (output, output_bias) tuple
        text_logits, _ = self.text_head(selected)
        mel_logits, _ = self.mel_head(selected)
        return torch.cat([text_logits, mel_logits], dim=1)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
"""Sample next tokens using VLLM sampler."""
return self.sampler(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""Load weights following VLLM pattern."""
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if name not in params_dict:
continue
param = params_dict[name]
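            # HF GPT-2 checkpoints store these projections as Conv1D modules, so
            # the weight matrices need to be transposed for Linear-style layers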
if "c_attn" in name or "c_proj" in name or "c_fc" in name:
if name.endswith(".weight"):
loaded_weight = loaded_weight.t()
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
class XttsGPT2Model(nn.Module):
"""VLLM-style implementation of GPT2 core architecture."""
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional["QuantizationConfig"] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.text_embedding = VocabParallelEmbedding(config.number_text_tokens, config.n_embd)
self.mel_embedding = VocabParallelEmbedding(config.num_audio_tokens, config.n_embd)
        self.text_pos_embedding = (
            LearnedPositionEmbeddings(config.max_text_seq_len, config.n_embd)
            if config.max_text_seq_len != -1
            else functools.partial(null_position_embeddings, dim=config.n_embd)
        )
        self.mel_pos_embedding = (
            LearnedPositionEmbeddings(config.max_mel_seq_len, config.n_embd)
            if config.max_mel_seq_len != -1
            else functools.partial(null_position_embeddings, dim=config.n_embd)
        )
# Build gpt blocks
self.h = nn.ModuleList([
GPT2Block(
config,
cache_config,
quant_config,
prefix=f"{prefix}.h.{i}"
) for i in range(config.num_hidden_layers)
])
self.final_norm = nn.LayerNorm(
config.n_embd,
eps=config.layer_norm_epsilon
)
    def forward(  # TODO: still approximate; align it with the reference XTTS GPT forward
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                # Text tokens only; mel tokens would need self.mel_embedding instead
                inputs_embeds = self.text_embedding(input_ids)
            if isinstance(self.text_pos_embedding, LearnedPositionEmbeddings):
                # LearnedPositionEmbeddings wraps an nn.Embedding table, so vLLM's
                # explicit position ids can be looked up directly
                position_embeds = self.text_pos_embedding.emb(positions)
            else:
                position_embeds = 0
            # TODO: positions do not yet account for prepended conditioning latents
            hidden_states = inputs_embeds + position_embeds
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        # All blocks live on this rank (no pipeline partitioning of self.h)
        for i, layer in enumerate(self.h):
            hidden_states = layer(hidden_states,
                                  kv_caches[i],
                                  attn_metadata)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.final_norm(hidden_states)
        return hidden_states
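# Usage sketch (a minimal example, assuming vLLM's out-of-tree model registration
# API; the architecture name must match the one reported by the HF config):
#
#   from vllm import ModelRegistry
#   ModelRegistry.register_model("XttsGPT", XttsGPT)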