| """Flow matching audio head for speech-to-speech. |
| |
| Generates audio from LLM hidden states via flow matching: |
| LLM hidden -> llm_proj -> flow_net (LSD decode) -> Mimi latents -> Mimi decoder -> audio |
| |
| Supports two modes: |
| 1. Training from scratch with 512-dim Mimi embeddings (latent_proj_in/out) |
| 2. Using pretrained pocket-tts flow_net with 32-dim normalized latents |
| """ |
|
|
| import logging |
| from functools import partial |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from .modules.mlp import SimpleMLPAdaLN |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def lsd_decode(
    v_t,
    x_0: torch.Tensor,
    num_steps: int = 1,
) -> torch.Tensor:
    """Lagrangian Self-Distillation decoding.

    Integrates the flow ODE from noise toward data with `num_steps`
    uniform Euler steps, querying the velocity network once per step.

    Args:
        v_t: Velocity function v(s, t, x) -> velocity, where s and t are
            time tensors shaped like x_0[..., :1]
        x_0: Initial noise, shape [N, latent_dim]
        num_steps: Number of integration steps (0 returns x_0 unchanged)

    Returns:
        Decoded latents, shape [N, latent_dim]
    """
    x = x_0
    # Template used to materialize scalar times as tensors with one time
    # value per row of x_0.
    time_template = x_0[..., :1]
    for step in range(num_steps):
        # Endpoints [start, stop] of this Euler step on the unit time axis.
        # Divisions stay inside the loop so num_steps == 0 degrades to a
        # no-op instead of a ZeroDivisionError.
        start = step / num_steps
        stop = (step + 1) / num_steps
        velocity = v_t(
            torch.full_like(time_template, start),
            torch.full_like(time_template, stop),
            x,
        )
        x = x + velocity / num_steps
    return x
|
|
|
|
class AudioHead(nn.Module):
    """Flow matching head: LLM hidden -> Mimi latents -> audio.

    Architecture:
        - llm_proj: Linear projection from LLM hidden dim to flow conditioning
        - latent_proj_in/out: Project between Mimi 512-dim and flow 32-dim
        - flow_net: SimpleMLPAdaLN that predicts flow velocity
        - Mimi decoder for latent -> audio

    Args:
        config: ASRConfig with:
            - llm_dim: LLM hidden dimension (default: 2048)
            - lsd_decode_steps: Number of LSD integration steps (default: 1)
            - flow_temperature: Sampling temperature for noise (default: 1.0)
        llm_dim: Optional explicit LLM hidden dim; overrides config.llm_dim.
    """

    COND_DIM = 1024    # width of the conditioning vector fed to flow_net
    LATENT_DIM = 32    # compact latent space the flow network operates in
    MIMI_DIM = 512     # Mimi embedding width
    FLOW_DIM = 512     # flow_net hidden width
    FLOW_DEPTH = 6     # number of residual blocks in flow_net

    def __init__(self, config, llm_dim: Optional[int] = None):
        super().__init__()

        # Resolution order: explicit argument, then config.llm_dim, then 2048.
        # NOTE(review): `or`-chaining means a falsy value (0) falls through to
        # the next default — presumably 0 is never a valid dim.
        self.llm_dim = llm_dim or getattr(config, "llm_dim", None) or 2048
        self.cond_dim = self.COND_DIM
        self.latent_dim = self.LATENT_DIM
        self.mimi_dim = self.MIMI_DIM
        self.lsd_steps = getattr(config, "lsd_decode_steps", 1)
        self.temp = getattr(config, "flow_temperature", 1.0)

        # LLM hidden state -> flow conditioning vector.
        self.llm_proj = nn.Linear(self.llm_dim, self.cond_dim, bias=False)

        # Projections between Mimi's 512-dim embedding space and the 32-dim
        # flow latent space (used in mode 1, training from scratch).
        self.latent_proj_in = nn.Linear(self.mimi_dim, self.latent_dim, bias=False)
        self.latent_proj_out = nn.Linear(self.latent_dim, self.mimi_dim, bias=False)

        # Velocity network. num_time_conds=2 matches the two time tensors
        # (s, t) passed in _compute_loss and lsd_decode.
        self.flow_net = SimpleMLPAdaLN(
            in_channels=self.latent_dim,
            model_channels=self.FLOW_DIM,
            out_channels=self.latent_dim,
            cond_channels=self.cond_dim,
            num_res_blocks=self.FLOW_DEPTH,
            num_time_conds=2,
        )

        # Latent normalization statistics (identity by default); populated by
        # load_pretrained_flow_net() for mode 2 (pretrained pocket-tts).
        self.register_buffer("emb_mean", torch.zeros(self.latent_dim))
        self.register_buffer("emb_std", torch.ones(self.latent_dim))
        self._use_pretrained_normalization = False

        # Lazily loaded Mimi model; see load_mimi_decoder().
        self.mimi = None

    def load_mimi_decoder(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Load Mimi model for decoding latents to audio.

        Args:
            device: Optional device to move the Mimi model to.
            dtype: Optional dtype to cast the Mimi model to.
        """
        from transformers import MimiModel

        self.mimi = MimiModel.from_pretrained("kyutai/mimi")
        # Frozen inference-only component: never trained with the head.
        self.mimi.requires_grad_(False)
        self.mimi.eval()

        if device is not None:
            self.mimi = self.mimi.to(device)
        if dtype is not None:
            self.mimi = self.mimi.to(dtype)

        logger.info("Loaded Mimi decoder from kyutai/mimi")

    def load_pretrained_flow_net(
        self,
        weights_path: Optional[str] = None,
        freeze: bool = True,
    ):
        """Load pretrained pocket-tts flow_net weights.

        This enables using the pretrained flow matching network from pocket-tts,
        which operates in normalized 32-dim latent space.

        Args:
            weights_path: Path to safetensors file. If None, downloads from HuggingFace.
            freeze: Whether to freeze flow_net weights (default: True, only train llm_proj)
        """
        import safetensors.torch

        if weights_path is None:
            from huggingface_hub import hf_hub_download

            weights_path = hf_hub_download(
                repo_id="kyutai/pocket-tts", filename="tts_b6369a24.safetensors"
            )

        state = safetensors.torch.load_file(weights_path)

        # Extract the flow_net subtree by stripping the checkpoint prefix.
        flow_state = {}
        for k, v in state.items():
            if k.startswith("flow_lm.flow_net."):
                new_key = k.replace("flow_lm.flow_net.", "")
                flow_state[new_key] = v

        self.flow_net.load_state_dict(flow_state)
        logger.info(f"Loaded pretrained flow_net from {weights_path}")

        # Normalization stats for the 32-dim latent space, if present.
        if "flow_lm.emb_mean" in state:
            self.emb_mean.copy_(state["flow_lm.emb_mean"])
        if "flow_lm.emb_std" in state:
            self.emb_std.copy_(state["flow_lm.emb_std"])

        # NOTE(review): this flag is set even when the checkpoint lacks
        # emb_mean/emb_std, in which case the zero/one buffer defaults make
        # de-normalization in _generate a no-op — confirm that is intended.
        self._use_pretrained_normalization = True
        logger.info("Loaded emb_mean and emb_std for normalization")

        if freeze:
            self.flow_net.requires_grad_(False)
            logger.info("Froze flow_net weights (only llm_proj will train)")

    def forward(
        self,
        hidden_states: torch.Tensor,
        latent_targets: Optional[torch.Tensor] = None,
        latent_lengths: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass for training or inference.

        Mode is selected by the presence of latent_targets: with targets,
        returns the training loss; without, generates latents.

        Args:
            hidden_states: LLM hidden states, shape [batch, seq_len, llm_dim]
            latent_targets: Target Mimi latents for training, shape [batch, seq_len, 512]
            latent_lengths: Actual lengths per sample, shape [batch]

        Returns:
            Training: scalar flow matching loss
            Inference: generated Mimi latents, shape [batch, seq_len, 512]
        """
        # Project LLM hidden states into the flow conditioning space.
        cond = self.llm_proj(hidden_states)

        if latent_targets is not None:
            return self._compute_loss(cond, latent_targets, latent_lengths)
        return self._generate(cond)

    def _compute_loss(
        self,
        cond: torch.Tensor,
        targets: torch.Tensor,
        lengths: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Compute flow matching loss with reconstruction term.

        The loss has two components:
        1. Flow matching loss: MSE between predicted and target velocities in 32-dim space
        2. Reconstruction loss: MSE between reconstructed and original 512-dim embeddings
           (this ensures latent_proj_out is trained)

        Args:
            cond: Conditioning from LLM, shape [batch, cond_seq_len, cond_dim]
            targets: Mimi embeddings, shape [batch, target_seq_len, 512]
            lengths: Optional lengths for masking

        Returns:
            Scalar loss tensor (flow_loss + 0.1 * recon_loss).
        """
        # Diagnostics only: NaN/Inf here usually indicates upstream
        # instability; training proceeds either way.
        if torch.isnan(cond).any() or torch.isinf(cond).any():
            logger.warning(
                f"NaN/Inf in cond! shape={cond.shape}, nan={torch.isnan(cond).sum()}, inf={torch.isinf(cond).sum()}"
            )
        if torch.isnan(targets).any() or torch.isinf(targets).any():
            logger.warning(f"NaN/Inf in targets! shape={targets.shape}")

        batch, cond_seq_len, _ = cond.shape
        target_seq_len = targets.shape[1]
        device = cond.device
        dtype = cond.dtype

        # Degenerate batch: return a zero loss that still participates in
        # autograd so the training step does not crash.
        if cond_seq_len == 0 or target_seq_len == 0:
            return torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True)

        # 512-dim Mimi embeddings -> 32-dim flow latents.
        targets_proj = self.latent_proj_in(targets)

        # Round-trip through latent_proj_out so the output projection gets
        # gradient via the reconstruction term below.
        targets_reconstructed = self.latent_proj_in(targets) if False else self.latent_proj_out(targets_proj)

        # If the target frame rate differs from the conditioning rate,
        # linearly resample all three target streams along time to the
        # conditioning length so losses are computed position-by-position.
        targets_for_interp = targets
        if target_seq_len != cond_seq_len:
            # interpolate() expects [batch, channels, time].
            targets_proj = targets_proj.transpose(1, 2)
            targets_proj = torch.nn.functional.interpolate(
                targets_proj, size=cond_seq_len, mode="linear", align_corners=False
            )
            targets_proj = targets_proj.transpose(1, 2).contiguous()

            targets_for_interp = targets.transpose(1, 2)
            targets_for_interp = torch.nn.functional.interpolate(
                targets_for_interp, size=cond_seq_len, mode="linear", align_corners=False
            )
            targets_for_interp = targets_for_interp.transpose(1, 2).contiguous()

            targets_reconstructed = targets_reconstructed.transpose(1, 2)
            targets_reconstructed = torch.nn.functional.interpolate(
                targets_reconstructed, size=cond_seq_len, mode="linear", align_corners=False
            )
            targets_reconstructed = targets_reconstructed.transpose(1, 2).contiguous()

        if lengths is not None:
            # Rescale per-sample lengths to the conditioning time base.
            # NOTE(review): .long() truncates, so a trailing frame can be
            # lost to rounding — confirm this is acceptable.
            scale = cond_seq_len / target_seq_len
            lengths = (lengths.float() * scale).long()

        seq_len = cond_seq_len
        x_1 = targets_proj  # data endpoint of the flow

        # Independent time sample per (batch, position).
        t = torch.rand(batch, seq_len, 1, device=device, dtype=dtype)

        # Noise endpoint.
        x_0 = torch.randn_like(x_1)

        # Linear interpolant between noise and data at time t.
        x_t = (1 - t) * x_0 + t * x_1

        # Target velocity of the straight-line flow: constant along the path.
        v_target = x_1 - x_0

        # flow_net consumes flat [N, ...] inputs; fold batch and time.
        cond_flat = cond.view(-1, self.cond_dim)
        t_flat = t.view(-1, 1)
        x_t_flat = x_t.view(-1, self.latent_dim)

        # NOTE(review): both time conditions receive the same t here (s == t),
        # whereas lsd_decode passes a distinct interval [s, t] at inference —
        # confirm this matches the flow_net's training convention.
        v_pred = self.flow_net(cond_flat, t_flat, t_flat, x_t_flat)
        v_pred = v_pred.view(batch, seq_len, self.latent_dim)

        if lengths is not None:
            # Boolean mask of valid (non-padded) positions per sample.
            positions = torch.arange(seq_len, device=device).unsqueeze(0)
            mask = positions < lengths.unsqueeze(1)

            # Entirely-padded batch: differentiable zero loss (see above).
            if not mask.any():
                return torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True)

            flow_mask = mask.unsqueeze(-1).expand_as(v_pred)
            recon_mask = mask.unsqueeze(-1).expand_as(targets_reconstructed)

            flow_loss = ((v_pred - v_target) ** 2)[flow_mask].mean()
            recon_loss = ((targets_reconstructed - targets_for_interp) ** 2)[recon_mask].mean()
        else:
            flow_loss = ((v_pred - v_target) ** 2).mean()
            recon_loss = ((targets_reconstructed - targets_for_interp) ** 2).mean()

        # Small weight keeps the reconstruction term from dominating the
        # flow matching objective.
        return flow_loss + 0.1 * recon_loss

    def _generate(self, cond: torch.Tensor) -> torch.Tensor:
        """Generate Mimi embeddings via LSD decoding.

        Args:
            cond: Conditioning from LLM, shape [batch, seq_len, cond_dim]

        Returns:
            Generated Mimi embeddings, shape [batch, seq_len, 512]
        """
        batch, seq_len, _ = cond.shape
        device = cond.device
        dtype = cond.dtype

        # Nothing to generate for an empty sequence.
        if seq_len == 0:
            return torch.empty(batch, 0, self.mimi_dim, device=device, dtype=dtype)

        # Negative temperatures are clamped to 0 (deterministic decode).
        temp = max(0.0, self.temp)

        latents = []
        for t in range(seq_len):
            cond_t = cond[:, t]

            # temp scales the noise *variance*, hence sqrt on the stddev.
            noise = torch.randn(batch, self.latent_dim, device=device, dtype=dtype)
            noise = noise * (temp**0.5)

            # Bind this step's conditioning so lsd_decode sees v(s, t, x).
            def velocity_fn(cond_fixed, s, t, x):
                return self.flow_net(cond_fixed, s, t, x)

            conditioned_flow = partial(velocity_fn, cond_t)
            latent = lsd_decode(conditioned_flow, noise, self.lsd_steps)
            latents.append(latent)

        latents = torch.stack(latents, dim=1)

        # Mode 2: the pretrained flow_net emits normalized latents; undo the
        # normalization before projecting back to Mimi space.
        if self._use_pretrained_normalization:
            latents = latents * self.emb_std + self.emb_mean

        # 32-dim flow latents -> 512-dim Mimi embeddings.
        return self.latent_proj_out(latents)

    def decode_to_audio(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode Mimi latents to audio waveform.

        Note: HuggingFace MimiModel.decode() expects discrete codes, not continuous
        embeddings. We bypass the quantizer and call upsample -> decoder_transformer
        -> decoder directly to decode from continuous latents.

        Args:
            latents: Mimi latents, shape [batch, seq_len, 512]

        Returns:
            Audio waveform, shape [batch, samples]
        """
        if self.mimi is None:
            raise RuntimeError("Mimi decoder not loaded. Call load_mimi_decoder() first.")

        # [batch, seq, 512] -> [batch, 512, seq] for the conv upsampler.
        latents = latents.transpose(1, 2)

        with torch.no_grad():
            emb = self.mimi.upsample(latents)

            # Transformer expects [batch, time, channels].
            emb = emb.transpose(1, 2)
            decoder_out = self.mimi.decoder_transformer(emb)
            # Accept both ModelOutput and tuple returns. NOTE(review):
            # getattr's default argument evaluates decoder_out[0] eagerly,
            # so this relies on the output being indexable either way.
            emb = getattr(decoder_out, "last_hidden_state", decoder_out[0])

            # Back to [batch, channels, time] for the conv decoder.
            emb = emb.transpose(1, 2)
            audio = self.mimi.decoder(emb)

        # Assumes the decoder emits [batch, 1, samples] — squeeze the
        # channel dim to return [batch, samples].
        return audio.squeeze(1)

    def get_output_length(self, input_length: int) -> int:
        """Estimate output audio frames from input hidden state length.

        For Mimi at 12.5 Hz frame rate with 24kHz audio:
        Each latent frame = 24000 / 12.5 = 1920 audio samples
        """
        return input_length * 1920
|
|