| """Audio head for speech-to-speech using a frozen pretrained TTS backbone. |
| |
| Architecture: |
| Text → frozen LLM (SmolLM3-3B) → hidden states (llm_dim) |
| → Projector MLP (trainable, llm_dim → backbone_dim) |
| → Concat with codec embeddings → neutts-nano LlamaForCausalLM (frozen) |
| → lm_head → speech token logits → NeuCodec codes → audio |
| |
| The frozen LLM is loaded for standalone S2S training. When used inside a full |
| ASR pipeline (ASRModel), pre-computed LLM hidden states are passed directly |
| and the internal LLM is not used. |
| |
| neutts-nano (neuphonic/neutts-nano) is a pretrained 24-layer LlamaForCausalLM |
| (dim=576, ~117M params) that generates NeuCodec codes as <|speech_N|> tokens. |
| Only the projector MLP is trained. |
| |
| NeuCodec uses a single FSQ codebook (levels=[4]*8, vocab=65536) at 50 tokens/sec, |
| outputting 24kHz audio. Codes 0-65535 map to neutts-nano tokens <|speech_0|>..<|speech_65535|>. |
| """ |
|
|
| import logging |
| from dataclasses import dataclass |
| from typing import Iterator, Optional |
|
|
| import torch |
| import torch.nn as nn |
| from torch.nn import functional as F |
| from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, PreTrainedModel |
| from transformers.modeling_outputs import ModelOutput |
|
|
| logger = logging.getLogger(__name__) |
|
|
| |
# Single FSQ codebook size of NeuCodec (levels [4]*8 → 4**8 = 65536 codes).
NEUCODEC_VOCAB_SIZE = 65536
# NeuCodec decodes to 24 kHz audio (see module docstring).
NEUCODEC_SAMPLE_RATE = 24000


# Collator-side sentinel IDs appended after the codec vocabulary. These are
# NOT neutts-nano tokenizer IDs — AudioHead._map_collator_ids_to_speech /
# _map_collator_labels_to_speech translate them into backbone token IDs.
BOS_TOKEN = NEUCODEC_VOCAB_SIZE
EOS_TOKEN = NEUCODEC_VOCAB_SIZE + 1
PAD_TOKEN = NEUCODEC_VOCAB_SIZE + 2
TOTAL_VOCAB = NEUCODEC_VOCAB_SIZE + 3
|
|
|
|
class AudioHeadConfig(PretrainedConfig):
    """Configuration for AudioHead with frozen TTS backbone + trainable projector.

    Args:
        tts_model_id: HF Hub ID of the frozen neutts-nano TTS backbone.
        llm_model_id: HF Hub ID of the frozen text LLM.
        projector_hidden: Hidden width of the trainable projector MLP.
        max_audio_tokens: Cap on the number of speech tokens generated per sample.
        neucodec_model_id: HF Hub ID of the NeuCodec audio decoder.
        temperature: Sampling temperature used at generation time.
        top_k: Top-k filtering for sampling (0 disables filtering).
        **kwargs: Forwarded to PretrainedConfig.
    """

    model_type = "audio_head"

    def __init__(
        self,
        tts_model_id: str = "neuphonic/neutts-nano",
        llm_model_id: str = "HuggingFaceTB/SmolLM3-3B",
        projector_hidden: int = 1024,
        max_audio_tokens: int = 500,
        neucodec_model_id: str = "neuphonic/neucodec",
        temperature: float = 1.0,
        top_k: int = 50,
        **kwargs,
    ):
        # Record model IDs first, then architecture and sampling knobs; the
        # remaining kwargs go to the PretrainedConfig base class.
        self.tts_model_id = tts_model_id
        self.llm_model_id = llm_model_id
        self.neucodec_model_id = neucodec_model_id
        self.projector_hidden = projector_hidden
        self.max_audio_tokens = max_audio_tokens
        self.top_k = top_k
        self.temperature = temperature
        super().__init__(**kwargs)
|
|
|
|
@dataclass
class AudioHeadOutput(ModelOutput):
    """Output of AudioHead forward pass.

    Exactly one field is populated per call: ``loss`` when codec_labels are
    provided (training), ``codes`` otherwise (inference).

    Attributes:
        loss: Cross-entropy loss when codec_labels are provided.
        codes: Generated NeuCodec codes in inference mode [batch, gen_len].
    """

    loss: Optional[torch.Tensor] = None
    codes: Optional[torch.Tensor] = None
|
|
|
|
class AudioHead(PreTrainedModel):
    """Frozen TTS backbone + trainable projector for speech generation.

    Loads neutts-nano (a pretrained LlamaForCausalLM that generates NeuCodec tokens)
    and freezes it entirely. A frozen LLM converts text to hidden states, and a
    trainable MLP projector maps those hidden states into neutts-nano's input space.
    Only the projector is trained; ``state_dict`` is overridden so checkpoints
    contain projector weights only.

    Standalone training: text_token_ids → frozen LLM → hidden states → projector → backbone → speech codes
    Pipeline inference: llm_hidden_states → projector → backbone → speech codes
    """

    config_class = AudioHeadConfig

    # Frozen submodules are materialized eagerly in _load_backbone; disable the
    # HF buffer/param assignment shortcut so loading copies into them instead.
    _supports_param_buffer_assignment = False

    def __init__(self, config: AudioHeadConfig):
        super().__init__(config)
        self.max_tokens = config.max_audio_tokens

        # Under a meta-device init context skip the heavy Hub downloads; the
        # real load runs when the model is constructed on a concrete device.
        self._backbone_loaded = False
        if not self._is_meta_init():
            self._load_backbone(config)

    def _is_meta_init(self) -> bool:
        """Check if we're inside a meta device context manager.

        Inside ``with torch.device("meta")`` a fresh tensor lands on the meta
        device — use that as the detection signal.
        """
        try:
            test = torch.empty(1)
            return test.device.type == "meta"
        except Exception:
            return False

    def _load_backbone(self, config: AudioHeadConfig) -> None:
        """Load the frozen TTS backbone, frozen LLM, and initialize the projector.

        Idempotent: a second call is a no-op once ``_backbone_loaded`` is set.
        """
        if self._backbone_loaded:
            return

        logger.info("Loading TTS backbone: %s", config.tts_model_id)
        self.backbone = AutoModelForCausalLM.from_pretrained(
            config.tts_model_id,
            torch_dtype=torch.bfloat16,
        )
        self.backbone.requires_grad_(False)
        self.backbone.eval()

        self.tts_tokenizer = AutoTokenizer.from_pretrained(config.tts_model_id)

        # neutts-nano exposes NeuCodec codes as contiguous <|speech_N|> tokens;
        # cache the offset of <|speech_0|> and the generation delimiters.
        self.speech_token_offset = self.tts_tokenizer.convert_tokens_to_ids("<|speech_0|>")
        self.speech_start_id = self.tts_tokenizer.convert_tokens_to_ids(
            "<|SPEECH_GENERATION_START|>"
        )
        self.speech_end_id = self.tts_tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")

        logger.info("Loading frozen LLM: %s", config.llm_model_id)
        self.llm = AutoModelForCausalLM.from_pretrained(
            config.llm_model_id,
            torch_dtype=torch.bfloat16,
        )
        self.llm.requires_grad_(False)
        self.llm.eval()

        # Tokenize the fixed instruction prefix once; forward() prepends it to
        # every text batch. Non-persistent: it is rebuilt on load, not saved.
        llm_tokenizer = AutoTokenizer.from_pretrained(config.llm_model_id, trust_remote_code=True)
        prompt_enc = llm_tokenizer(
            "Speak the following text aloud: ",
            return_tensors="pt",
            add_special_tokens=True,
        )
        self.register_buffer(
            "_prompt_prefix_ids",
            prompt_enc.input_ids,
            persistent=False,
        )
        self._prompt_len = prompt_enc.input_ids.shape[1]

        llm_dim = self.llm.config.hidden_size
        backbone_dim = self.backbone.config.hidden_size

        # Local import keeps the transformers-internal dependency at load time.
        from transformers.models.llama.modeling_llama import LlamaRMSNorm

        # The only trainable piece: llm_dim → projector_hidden → backbone_dim,
        # RMS-normalized to match the LLaMA-style backbone activations.
        self.projector = nn.Sequential(
            nn.Linear(llm_dim, config.projector_hidden),
            LlamaRMSNorm(config.projector_hidden, eps=1e-6),
            nn.GELU(),
            nn.Linear(config.projector_hidden, backbone_dim),
            LlamaRMSNorm(backbone_dim, eps=1e-6),
        ).to(torch.bfloat16)

        self.temperature = config.temperature
        self.top_k = config.top_k

        # NeuCodec is loaded lazily on the first decode_to_audio() call.
        self.neucodec_model = None

        self._backbone_loaded = True

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load AudioHead: config + projector weights from disk/Hub, backbone from HF Hub.

        Only the projector weights live in the checkpoint; the frozen backbone
        and LLM are re-downloaded from the Hub via the IDs stored in config.
        Extra *args/**kwargs are accepted for signature compatibility but unused.
        """
        from pathlib import Path

        from safetensors.torch import load_file

        path = Path(pretrained_model_name_or_path)

        # Resolve a Hub repo ID to a local snapshot directory.
        if not path.is_dir():
            from huggingface_hub import snapshot_download

            path = Path(snapshot_download(pretrained_model_name_or_path))

        config = AudioHeadConfig.from_pretrained(path)

        # Constructing the model loads the frozen backbone/LLM from the Hub.
        model = cls(config)

        # strict=False: the checkpoint holds only "projector.*" keys.
        safetensors_path = path / "model.safetensors"
        if safetensors_path.exists():
            projector_state = load_file(safetensors_path)
            model.load_state_dict(projector_state, strict=False)
            logger.info("Loaded projector weights from %s", safetensors_path)

        return model

    def train(self, mode: bool = True):
        """Override to keep backbone and LLM in eval mode (disables dropout, etc.).

        Guarded with getattr: when __init__ ran under a meta-device context the
        frozen submodules were never created, and the original unconditional
        attribute access would raise AttributeError.
        """
        super().train(mode)
        backbone = getattr(self, "backbone", None)
        if backbone is not None:
            backbone.eval()
        llm = getattr(self, "llm", None)
        if llm is not None:
            llm.eval()
        return self

    def _embed_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Embed tokens using the frozen backbone's embedding table."""
        return self.backbone.model.embed_tokens(token_ids)

    def _codec_to_speech_ids(self, codec_codes: torch.Tensor) -> torch.Tensor:
        """Map NeuCodec codes [0, 65535] to neutts-nano speech token IDs."""
        return codec_codes + self.speech_token_offset

    def _speech_ids_to_codec(self, speech_ids: torch.Tensor) -> torch.Tensor:
        """Map neutts-nano speech token IDs back to NeuCodec codes [0, 65535]."""
        return speech_ids - self.speech_token_offset

    def forward(
        self,
        text_token_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        llm_hidden_states: Optional[torch.Tensor] = None,
        codec_labels: Optional[torch.Tensor] = None,
        codec_input_ids: Optional[torch.Tensor] = None,
        codec_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> AudioHeadOutput:
        """Forward pass for training or inference.

        Args:
            text_token_ids: Text token IDs [batch, seq_len] (LLM tokenizer vocab).
                Run through frozen LLM to get hidden states. Mutually exclusive
                with llm_hidden_states.
            attention_mask: Text attention mask [batch, seq_len] (1=real, 0=padding)
            llm_hidden_states: Pre-computed LLM hidden states [batch, seq_len, llm_dim].
                Used in pipeline mode when ASRModel provides hidden states directly.
            codec_labels: Target NeuCodec codes [batch, audio_len] (-100 for ignore)
            codec_input_ids: Teacher-forced NeuCodec codes [batch, audio_len]
            codec_attention_mask: Codec attention mask [batch, audio_len]
            **kwargs: Absorbed silently (Trainer may pass extra keys).

        Returns:
            AudioHeadOutput with loss (training) or codes (inference).
        """
        if llm_hidden_states is not None:
            hidden_states = llm_hidden_states
        elif text_token_ids is not None:
            # Prepend the fixed instruction prompt, then run the frozen LLM.
            batch_size = text_token_ids.shape[0]
            device = text_token_ids.device
            prompt = self._prompt_prefix_ids.expand(batch_size, -1).to(device)
            full_ids = torch.cat([prompt, text_token_ids], dim=1)

            if attention_mask is not None:
                prompt_mask = torch.ones(
                    batch_size, self._prompt_len, device=device, dtype=attention_mask.dtype
                )
                full_mask = torch.cat([prompt_mask, attention_mask], dim=1)
            else:
                full_mask = None

            with torch.no_grad():
                llm_out = self.llm.model(
                    input_ids=full_ids,
                    attention_mask=full_mask,
                )
            # Strip the prompt positions so downstream lengths match the text.
            hidden_states = llm_out.last_hidden_state[:, self._prompt_len :]
        else:
            raise ValueError("Either text_token_ids or llm_hidden_states must be provided")

        batch_size, text_len = hidden_states.shape[:2]
        device = hidden_states.device

        # Project LLM hidden states into the backbone's embedding space; this
        # is the only computation that carries gradients.
        prefix = self.projector(hidden_states)

        if codec_labels is None:
            # Inference: autoregressively sample speech codes.
            codes = self._generate(prefix, attention_mask)
            return AudioHeadOutput(codes=codes)

        # Training: teacher-forced pass through the frozen backbone.
        assert codec_input_ids is not None, "codec_input_ids required when codec_labels provided"

        # Translate collator sentinel IDs into backbone tokenizer IDs.
        speech_input = self._map_collator_ids_to_speech(codec_input_ids)

        with torch.no_grad():
            token_emb = self._embed_tokens(speech_input)

        audio_len = token_emb.shape[1]

        # [projected text prefix | codec token embeddings]
        hidden = torch.cat([prefix, token_emb], dim=1)

        prefix_mask = (
            attention_mask
            if attention_mask is not None
            else torch.ones(batch_size, text_len, device=device, dtype=torch.long)
        )
        audio_mask = (
            codec_attention_mask
            if codec_attention_mask is not None
            else torch.ones(batch_size, audio_len, device=device, dtype=torch.long)
        )
        combined_mask = torch.cat([prefix_mask, audio_mask], dim=1)

        # Base model (no LM head) so we can apply lm_head to the audio span only.
        outputs = self.backbone.model(
            inputs_embeds=hidden,
            attention_mask=combined_mask,
        )

        # Only audio positions contribute to the loss.
        audio_hidden = outputs.last_hidden_state[:, text_len:]

        logits = self.backbone.lm_head(audio_hidden)

        speech_labels = self._map_collator_labels_to_speech(codec_labels)

        # NOTE(review): logits[t] predicts the token after speech_input[t], so
        # this assumes the collator pre-shifts codec_labels — confirm upstream.
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            speech_labels.view(-1),
            ignore_index=-100,
        )

        return AudioHeadOutput(loss=loss)

    def _map_collator_ids_to_speech(self, codec_input_ids: torch.Tensor) -> torch.Tensor:
        """Map S2SDataCollator codec_input_ids to neutts-nano token IDs.

        S2SDataCollator produces:
        - BOS_TOKEN (65536) at position 0
        - NeuCodec codes (0-65535) for real audio
        - PAD_TOKEN (65538) for padding

        Maps to:
        - BOS_TOKEN → <|SPEECH_GENERATION_START|>
        - codes 0-65535 → <|speech_0|>..<|speech_65535|>
        - PAD_TOKEN → pad_token_id
        """
        result = codec_input_ids.clone()

        # BOS sentinel → generation-start token.
        bos_mask = codec_input_ids == NEUCODEC_VOCAB_SIZE
        result[bos_mask] = self.speech_start_id

        # EOS sentinel → generation-end token (may appear in inputs too).
        eos_mask = codec_input_ids == (NEUCODEC_VOCAB_SIZE + 1)
        result[eos_mask] = self.speech_end_id

        # PAD sentinel → tokenizer pad token.
        # NOTE(review): assumes tts_tokenizer.pad_token_id is not None — confirm
        # the neutts-nano tokenizer defines one.
        pad_mask = codec_input_ids == (NEUCODEC_VOCAB_SIZE + 2)
        result[pad_mask] = self.tts_tokenizer.pad_token_id

        # Real codec codes → contiguous <|speech_N|> token range.
        codec_mask = codec_input_ids < NEUCODEC_VOCAB_SIZE
        result[codec_mask] = codec_input_ids[codec_mask] + self.speech_token_offset

        return result

    def _map_collator_labels_to_speech(self, codec_labels: torch.Tensor) -> torch.Tensor:
        """Map S2SDataCollator codec_labels to neutts-nano token IDs.

        codec_labels contains:
        - NeuCodec codes (0-65535) for real targets
        - EOS_TOKEN (65537) at the end
        - -100 for ignore positions (preserved untouched)
        """
        result = codec_labels.clone()

        valid = codec_labels != -100

        # EOS sentinel → generation-end token.
        eos_mask = valid & (codec_labels == (NEUCODEC_VOCAB_SIZE + 1))
        result[eos_mask] = self.speech_end_id

        # Real codec codes → contiguous <|speech_N|> token range.
        codec_mask = valid & (codec_labels < NEUCODEC_VOCAB_SIZE)
        result[codec_mask] = codec_labels[codec_mask] + self.speech_token_offset

        return result

    def _generate(
        self, prefix: torch.Tensor, prefix_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """AR generation with KV cache on frozen backbone.

        Args:
            prefix: Projected text embeddings [batch, text_len, backbone_dim].
            prefix_mask: Attention mask for prefix tokens (unused for now,
                reserved for batched generation with padding).

        Returns:
            Sampled NeuCodec codes [batch, gen_len].

        NOTE(review): with batch > 1, sequences that hit EOS early keep
        contributing (clamped) codes until every sequence finishes — callers
        should trim per-sequence output if batched generation is used.
        """
        _ = prefix_mask
        batch_size, text_len, _ = prefix.shape
        device = prefix.device

        all_codes = []

        # Prefill: [projected text | <|SPEECH_GENERATION_START|>] at
        # positions 0..text_len.
        start_token = torch.full(
            (batch_size, 1), self.speech_start_id, dtype=torch.long, device=device
        )
        start_emb = self._embed_tokens(start_token)
        hidden = torch.cat([prefix, start_emb], dim=1)

        position_ids = torch.arange(text_len + 1, device=device).unsqueeze(0).expand(batch_size, -1)

        with torch.no_grad():
            outputs = self.backbone.model(
                inputs_embeds=hidden,
                position_ids=position_ids,
                use_cache=True,
            )
        past_key_values = outputs.past_key_values
        last_hidden = outputs.last_hidden_state[:, -1:]

        for step in range(self.max_tokens):
            with torch.no_grad():
                logits = self.backbone.lm_head(last_hidden.squeeze(1))

            # Restrict sampling to the speech-code range plus the end token.
            speech_logits = logits[
                :, self.speech_token_offset : self.speech_token_offset + NEUCODEC_VOCAB_SIZE
            ]
            end_logit = logits[:, self.speech_end_id : self.speech_end_id + 1]

            # combined index 0..65535 = codes, index 65536 = end-of-speech.
            combined = torch.cat([speech_logits, end_logit], dim=-1)

            # Temperature then top-k filtering before sampling.
            if self.temperature != 1.0:
                combined = combined / self.temperature
            if self.top_k > 0:
                topk_vals, _ = combined.topk(min(self.top_k, combined.size(-1)))
                combined[combined < topk_vals[:, -1:]] = float("-inf")

            probs = F.softmax(combined, dim=-1)
            sampled = torch.multinomial(probs, 1).squeeze(-1)

            # Index NEUCODEC_VOCAB_SIZE is the end-of-speech slot.
            is_eos = sampled == NEUCODEC_VOCAB_SIZE
            if is_eos.all():
                break

            codec_code = sampled.clamp(0, NEUCODEC_VOCAB_SIZE - 1)
            all_codes.append(codec_code)

            next_token_id = codec_code + self.speech_token_offset
            # Finished sequences feed the end token so the cache stays coherent.
            next_token_id[is_eos] = self.speech_end_id

            next_emb = self._embed_tokens(next_token_id.unsqueeze(1))

            # Bug fix: the cache holds text_len+1 prefill tokens (positions
            # 0..text_len) plus `step` generated tokens, so the token being fed
            # now sits at absolute position text_len + 1 + step. The previous
            # `text_len + 1 + step + 1` skipped position text_len + 1, shifting
            # every generated token's rotary position by one.
            next_pos = torch.full(
                (batch_size, 1),
                text_len + 1 + step,
                dtype=torch.long,
                device=device,
            )

            with torch.no_grad():
                outputs = self.backbone.model(
                    inputs_embeds=next_emb,
                    position_ids=next_pos,
                    past_key_values=past_key_values,
                    use_cache=True,
                )
            past_key_values = outputs.past_key_values
            last_hidden = outputs.last_hidden_state

        if all_codes:
            codes = torch.stack(all_codes, dim=1)
        else:
            codes = torch.empty(batch_size, 0, dtype=torch.long, device=device)

        return codes

    def state_dict(self, *args, **kwargs):
        """Only save projector weights (backbone is frozen/pretrained)."""
        full = super().state_dict(*args, **kwargs)
        return {k: v for k, v in full.items() if k.startswith("projector.")}

    def _load_neucodec(self):
        """Load frozen NeuCodec model for audio decoding (lazy, first use only)."""
        from neucodec import NeuCodec

        self.neucodec_model = NeuCodec.from_pretrained(self.config.neucodec_model_id)
        self.neucodec_model.eval()
        self.neucodec_model.requires_grad_(False)
        logger.info("Loaded frozen NeuCodec model for audio decoding")

    def decode_to_audio(self, codes: torch.Tensor) -> list[torch.Tensor]:
        """Decode NeuCodec FSQ tokens to audio waveforms.

        Args:
            codes: Codec tokens [batch, seq_len] (values 0-65535)

        Returns:
            List of audio waveform tensors (one per batch item)
        """
        if self.neucodec_model is None:
            self._load_neucodec()
        assert self.neucodec_model is not None

        # decode_code expects [batch, n_codebooks=1, seq_len].
        # NOTE(review): relies on NeuCodec exposing a `.device` attribute —
        # confirm against the neucodec package.
        codes_3d = codes.unsqueeze(1).to(self.neucodec_model.device)

        with torch.no_grad():
            audio_values = self.neucodec_model.decode_code(codes_3d)

        return [audio_values[i, 0] for i in range(audio_values.shape[0])]

    def generate_streaming(
        self,
        text_token_ids: Optional[torch.Tensor] = None,
        llm_hidden_states: Optional[torch.Tensor] = None,
        chunk_samples: int = 24000,
    ) -> Iterator[torch.Tensor]:
        """Generate audio and yield waveform chunks for streaming playback.

        Note: generation and decoding run to completion before the first chunk
        is yielded — chunking happens only at the waveform level.
        """
        output = self(text_token_ids=text_token_ids, llm_hidden_states=llm_hidden_states)
        codes = output.codes
        audios = self.decode_to_audio(codes)

        for audio in audios:
            for start in range(0, audio.shape[-1], chunk_samples):
                end = min(start + chunk_samples, audio.shape[-1])
                yield audio[..., start:end]
|