# Copyright 2025 Xiaomi Corporation.
import math
from dataclasses import dataclass, field
from typing import List

import numpy as np
import torch
import torch.nn as nn
from flash_attn import flash_attn_varlen_func
from torch.nn import functional as F
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel

from .configuration_audio_tokenizer import MiMoAudioTokenizerConfig
from .modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update, apply_rotary_pos_emb
from .quantization import ResidualVectorQuantizer

def get_sequence_mask(inputs, inputs_length):
    if inputs.dim() == 3:
        bsz, tgt_len, _ = inputs.size()
    else:
        bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
    sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
    sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(
        bsz, tgt_len, 1
    )
    unpacking_index = torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1
    return sequence_mask, unpacking_index


def unpack_hidden_states(
    hidden_states, lengths, sequence_mask=None, unpacking_index=None
):
    bsz = lengths.shape[0]
    if sequence_mask is None or unpacking_index is None:
        sequence_mask, unpacking_index = get_sequence_mask(hidden_states, lengths)
    hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
        bsz, torch.max(lengths), hidden_states.shape[-1]
    )
    hidden_states = torch.where(sequence_mask, hidden_states, 0)
    return hidden_states


def get_position_ids(lengths):
    total_len = lengths.sum()
    offset = torch.cat([torch.zeros(1).to(lengths), lengths[:-1].cumsum(dim=0)])
    offset = torch.repeat_interleave(offset, lengths)
    position_ids = torch.arange(0, total_len).to(offset) - offset
    return position_ids
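
# Usage sketch (illustrative only, not executed at import time). The helpers above
# implement the packed / unpacked sequence layout used throughout this file: a batch
# of variable-length sequences is stored as one (sum(lengths), dim) tensor, and
# get_sequence_mask / unpack_hidden_states convert between that layout and the padded
# (bsz, max_len, dim) layout. Shapes below assume lengths = [3, 2] and dim = 4:
#
#   lengths = torch.tensor([3, 2])
#   padded = torch.randn(2, 3, 4)                            # (bsz, max_len, dim)
#   mask, unpack_idx = get_sequence_mask(padded, lengths)
#   packed = torch.masked_select(padded, mask).view(5, 4)    # (sum(lengths), dim)
#   restored = unpack_hidden_states(packed, lengths)         # back to (2, 3, 4), zero-padded
#   pos_ids = get_position_ids(lengths)                      # tensor([0, 1, 2, 0, 1])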

# Streaming segmentation / overlap sizes, expressed in tokenizer frames.
@dataclass
class StreamingConfig:
    seg_point: int = field(default=60 * 25)
    process_seg_point: bool = field(default=True)
    left_overlap: int = field(default=10 * 25)
    right_overlap: int = field(default=40)
    seg_point_left_overlap: int = field(default=0)


@dataclass
class StreamingCache:
    hidden_states: List[torch.Tensor] = field(default=None)
    processed_lengths: List[int] = field(default=None)

class ISTFT(nn.Module):
    """
    Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
    windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
    See issue: https://github.com/pytorch/pytorch/issues/62323
    Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
    The NOLA constraint is met as we trim padded samples anyway.

    Args:
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames.
        win_length (int): The size of window frame and STFT filter.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(
        self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
    ):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        window = torch.hann_window(win_length)
        self.register_buffer("window", window)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.

        Args:
            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
                N is the number of frequency bins, and T is the number of time frames.

        Returns:
            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
        """
        if self.padding == "center":
            # Fall back to the PyTorch native implementation
            return torch.istft(
                spec,
                self.n_fft,
                self.hop_length,
                self.win_length,
                self.window,
                center=True,
            )
        elif self.padding == "same":
            pad = (self.win_length - self.hop_length) // 2
        else:
            raise ValueError("Padding must be 'center' or 'same'.")
        assert spec.dim() == 3, "Expected a 3D tensor as input"
        B, N, T = spec.shape
        # Inverse FFT
        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        ifft = ifft * self.window[None, :, None]
        # Overlap and Add
        output_size = (T - 1) * self.hop_length + self.win_length
        y = torch.nn.functional.fold(
            ifft,
            output_size=(1, output_size),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        )[:, 0, 0, pad:-pad]
        # Window envelope
        window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
        window_envelope = torch.nn.functional.fold(
            window_sq,
            output_size=(1, output_size),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        ).squeeze()[pad:-pad]
        # Normalize
        assert (window_envelope > 1e-11).all()
        y = y / window_envelope
        return y
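
# Usage sketch (illustrative only): with "same" padding the module reconstructs a
# waveform of exactly T * hop_length samples from T spectrogram frames, which is the
# property the vocoder head relies on. Example shapes (values assumed, not taken from
# the config):
#
#   istft = ISTFT(n_fft=1024, hop_length=256, win_length=1024, padding="same")
#   spec = torch.randn(1, 513, 100, dtype=torch.complex64)   # (B, n_fft // 2 + 1, T)
#   wav = istft(spec)                                         # (1, 100 * 256) == (1, 25600)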

class ISTFTHead(nn.Module):
    """
    ISTFT Head module for predicting STFT complex coefficients.

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames, which should align with
            the resolution of the input features.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        out_dim = n_fft + 2
        self.out = torch.nn.Linear(dim, out_dim)
        self.istft = ISTFT(
            n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x).transpose(1, 2)
        mag, p = x.chunk(2, dim=1)
        mag = torch.exp(mag)
        mag = torch.clip(
            mag, max=1e2
        )  # safeguard to prevent excessively large magnitudes
        # Phase wrapping happens here; these two lines produce the real and imaginary parts.
        x = torch.cos(p)
        y = torch.sin(p)
        # Recalculating the phase here would not produce anything new and only costs time:
        #   phase = torch.atan2(y, x)
        #   S = mag * torch.exp(phase * 1j)
        # It is better to build the complex value directly.
        original_dtype = x.dtype
        S = mag.float() * (x.float() + 1j * y.float())
        audio = self.istft(S)
        audio = audio.to(original_dtype)
        return audio
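
# Usage sketch (illustrative only): the head projects each frame to n_fft + 2 values,
# interprets the first half as log-magnitude and the second half as phase, and feeds the
# resulting complex spectrogram to ISTFT. Example shapes (values assumed):
#
#   head = ISTFTHead(dim=512, n_fft=1024, hop_length=256, padding="same")
#   feats = torch.randn(1, 100, 512)    # (B, frames, hidden)
#   audio = head(feats)                 # (1, 100 * 256) waveform samples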

class RotaryEmbedding(nn.Module):
    def __init__(self, base, dim, max_seq_len, rope_type="default", device=None):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.rope_type = rope_type
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(
            device=device, base=base, dim=dim
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[:, None].float().expand(-1, 1).to(x.device)
        position_ids_expanded = position_ids[None, :].float()
        device_type = (
            x.device.type
            if isinstance(x.device.type, str) and x.device.type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (
                inv_freq_expanded.float() @ position_ids_expanded.float()
            ).transpose(0, 1)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
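
# Usage sketch (illustrative only): unlike the usual (bsz, seq_len) layout, this module
# takes a flat position_ids vector matching the packed token dimension, so cos/sin have
# shape (num_tokens, dim). Example (values assumed):
#
#   rope = RotaryEmbedding(base=10000, dim=64, max_seq_len=4096)
#   x = torch.randn(5, 64)                                    # packed tokens
#   cos, sin = rope(x, get_position_ids(torch.tensor([3, 2])))
#   # cos.shape == sin.shape == (5, 64)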

class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)
        return self.weight * hidden_states


LAYER_NORM = {"LayerNorm": nn.LayerNorm, "RMSNorm": RMSNorm}
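
# Note (illustrative): LAYER_NORM lets the config pick the normalization by name.
# RMSNorm normalizes by the root-mean-square only (no mean subtraction, no bias),
# computed in float32 for numerical stability. For example (values assumed):
#
#   norm = LAYER_NORM["RMSNorm"](512)
#   y = norm(torch.randn(5, 512))   # same shape, rescaled per feature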

class Attention(nn.Module):
    def __init__(self, embed_dim, num_heads, window_size=(-1, -1), causal=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.window_size = window_size
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.causal = causal

    def forward(
        self,
        hidden_states: torch.Tensor,
        seq_len: torch.Tensor,
        rope_position_embeddings=None,
    ):
        bsz, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states).view(
            bsz, self.num_heads, self.head_dim
        )
        key_states = self.k_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
        value_states = self.v_proj(hidden_states).view(
            bsz, self.num_heads, self.head_dim
        )
        if rope_position_embeddings is not None:
            cos, sin = rope_position_embeddings
            query_states = apply_rotary_pos_emb(query_states, cos, sin)
            key_states = apply_rotary_pos_emb(key_states, cos, sin)
        cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(
            torch.int32
        )
        max_seqlen = torch.max(seq_len).to(torch.int32).detach()
        attn_output = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_len,
            cu_len,
            max_seqlen,
            max_seqlen,
            causal=self.causal,
            window_size=self.window_size,
        )
        attn_output = attn_output.reshape(bsz, self.embed_dim)
        attn_output = self.out_proj(attn_output)
        return attn_output
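
# Usage sketch (illustrative only; requires a CUDA build of flash-attn): the module
# consumes the packed (total_tokens, embed_dim) layout and a per-sample length tensor;
# flash_attn_varlen_func uses the cumulative lengths so that samples in the packed batch
# never attend to each other. Example (values assumed):
#
#   attn = Attention(embed_dim=512, num_heads=8, causal=False).cuda().half()
#   tokens = torch.randn(5, 512, device="cuda", dtype=torch.float16)   # lengths [3, 2]
#   out = attn(tokens, torch.tensor([3, 2], device="cuda"))            # (5, 512)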

class TransformerLayer(nn.Module):
    def __init__(
        self,
        act,
        d_model,
        encoder_attention_heads,
        encoder_ffn_dim,
        causal,
        ln_type="LayerNorm",
        attn_window_size=(-1, -1),
    ):
        super().__init__()
        self.embed_dim = d_model
        self.self_attn = Attention(
            self.embed_dim, encoder_attention_heads, attn_window_size, causal
        )
        self.self_attn_layer_norm = LAYER_NORM[ln_type](self.embed_dim)
        self.activation_fn = act
        self.fc1 = nn.Linear(self.embed_dim, encoder_ffn_dim)
        self.fc2 = nn.Linear(encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = LAYER_NORM[ln_type](self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        seq_len: torch.Tensor,
        rope_position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states, seq_len, rope_position_embeddings=rope_position_embeddings
        )
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = self.fc2(hidden_states)
        hidden_states = residual + hidden_states
        if (
            hidden_states.dtype == torch.float16
            or hidden_states.dtype == torch.bfloat16
        ) and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(
                hidden_states, min=-clamp_value, max=clamp_value
            )
        return hidden_states
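
# Note (illustrative): this is a standard pre-norm Transformer block over packed
# sequences; attention and MLP each read a normalized copy of the input and are added
# back residually, and half-precision activations are clamped only if they overflow.
# Usage sketch (values assumed, CUDA flash-attn required):
#
#   layer = TransformerLayer(ACT2FN["gelu"], d_model=512, encoder_attention_heads=8,
#                            encoder_ffn_dim=2048, causal=False).cuda().half()
#   out = layer(torch.randn(5, 512, device="cuda", dtype=torch.float16),
#               torch.tensor([3, 2], device="cuda"),
#               rope_position_embeddings=None)   # (5, 512), same packed layout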

class TransformerVocos(nn.Module):
    def __init__(self, config: MiMoAudioTokenizerConfig):
        super().__init__()
        self.config = config
        self.max_source_positions = (
            self.config.max_audio_seconds
            * self.config.sampling_rate
            // self.config.hop_length
        )
        self.embeddings = nn.Linear(config.n_mels, config.vocoder_dim, bias=False)
        self.poisition_embedding = RotaryEmbedding(
            config.rope_theta,
            config.vocoder_dim // config.vocoder_attention_heads,
            self.max_source_positions,
            self.config.rope_type,
        )
        self.layers = nn.ModuleList(
            [
                TransformerLayer(
                    ACT2FN[self.config.activation_function],
                    self.config.vocoder_dim,
                    self.config.vocoder_attention_heads,
                    self.config.vocoder_intermediate_dim,
                    causal=False,
                    ln_type=self.config.ln_type,
                    attn_window_size=self.config.vocoder_attn_window_size,
                )
                for _ in range(self.config.vocoder_num_layers)
            ]
        )
        self.layer_norm = LAYER_NORM[self.config.ln_type](self.config.vocoder_dim)
        self.hop_size = self.config.hop_length
        self.head = ISTFTHead(
            self.config.vocoder_dim,
            self.config.nfft,
            self.config.hop_length,
            self.config.vocoder_padding,
        )

    def forward(self, x: torch.Tensor, input_length):
        x = x.transpose(1, 2)
        attention_mask, unpacking_index = get_sequence_mask(x, input_length)
        x = torch.masked_select(x, attention_mask).view(
            torch.sum(input_length), self.config.n_mels
        )
        x = self.embeddings(x)
        position_ids = torch.arange(0, x.size(0), device=x.device, dtype=torch.long)
        rope_position_embeddings = self.poisition_embedding(x, position_ids)
        for idx, layer in enumerate(self.layers):
            x = layer(
                x, input_length, rope_position_embeddings=rope_position_embeddings
            )
        x = self.layer_norm(x)
        x = unpack_hidden_states(x, input_length, attention_mask, unpacking_index)
        x = self.head(x)
        output_length = input_length * self.hop_size
        return x[:, None, :], output_length
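
# Usage sketch (illustrative only): the vocoder maps padded mel spectrograms
# (B, n_mels, T) plus per-sample frame counts to waveforms of roughly T * hop_length
# samples. With a loaded config `cfg` (an assumed MiMoAudioTokenizerConfig instance):
#
#   vocoder = TransformerVocos(cfg)
#   mels = torch.randn(1, cfg.n_mels, 200)            # (B, n_mels, frames)
#   wav, wav_len = vocoder(mels, torch.tensor([200]))
#   # wav: (1, 1, 200 * cfg.hop_length), wav_len: tensor([200 * cfg.hop_length])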

class AudioEncoder(nn.Module):
    def __init__(self, config: MiMoAudioTokenizerConfig):
        super().__init__()
        config._attn_implementation = "flash_attention_2"
        self.config = config
        self.max_source_positions = (
            config.max_audio_seconds * config.sampling_rate // config.hop_length
        ) // config.stride_size
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
        self.skip_layer_idx = config.encoder_skip_layer_id
        self.conv1 = nn.Conv1d(
            config.n_mels, config.d_model, kernel_size=config.kernel_size, padding=1
        )
        self.conv2 = nn.Conv1d(
            config.d_model,
            config.d_model,
            kernel_size=config.kernel_size,
            stride=config.stride_size,
            padding=1,
        )
        self.position_embedding = RotaryEmbedding(
            config.rope_theta,
            config.d_model // config.encoder_attention_heads,
            self.max_source_positions,
            config.rope_type,
        )
        self.layers = nn.ModuleList(
            [
                TransformerLayer(
                    ACT2FN[config.activation_function],
                    config.d_model,
                    config.encoder_attention_heads,
                    config.encoder_ffn_dim,
                    causal=self.config.encoder_causal,
                    ln_type=self.config.ln_type,
                    attn_window_size=self.config.encoder_attn_window_size,
                )
                for _ in range(config.encoder_layers)
            ]
        )
        self.layer_norm = LAYER_NORM[config.ln_type](config.d_model)
        if self.config.avg_pooler != 1:
            self.down_sample_layer = nn.Sequential(
                nn.Conv1d(
                    config.d_model,
                    config.d_model,
                    config.avg_pooler,
                    config.avg_pooler,
                    bias=False,
                ),
                nn.GELU(),
            )
            self.down_sample_norm = LAYER_NORM[config.ln_type](config.d_model)
        else:
            self.down_sample_layer = None
        if self.config.num_quantizers != 0:
            self.quantizer = ResidualVectorQuantizer(
                dimension=self.config.d_model,
                n_q=self.config.num_quantizers,
                bins=self.config.codebook_size,
                threshold_ema_dead_code=self.config.threshold_ema_dead_code,
            )
        else:
            self.quantizer = None

    def get_features(self, input_features, output_length):
        input_features = input_features.to(self.conv1.weight)
        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        bsz, tgt_len, _ = inputs_embeds.size()
        hidden_states = inputs_embeds
        position_ids = (
            get_position_ids(output_length).long().to(input_features.device)
        )
        rope_position_embeddings = self.position_embedding(
            input_features, position_ids
        )
        attention_mask, unpacking_index = get_sequence_mask(
            hidden_states, output_length
        )
        hidden_states = torch.masked_select(hidden_states, attention_mask).view(
            torch.sum(output_length), self.config.d_model
        )
        skip_connect_hidden_states = 0.0
        for idx, encoder_layer in enumerate(self.layers):
            hidden_states = encoder_layer(
                hidden_states,
                output_length,
                rope_position_embeddings=rope_position_embeddings,
            )
            if (self.skip_layer_idx is not None) and idx == self.skip_layer_idx - 1:
                skip_connect_hidden_states = hidden_states.clone()
        hidden_states += skip_connect_hidden_states
        hidden_states = self.layer_norm(hidden_states)
        if self.down_sample_layer is not None:
            hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
                bsz, tgt_len, self.config.d_model
            )
            if hidden_states.size(1) % self.config.avg_pooler:
                pad_len = (
                    self.config.avg_pooler
                    - hidden_states.size(1) % self.config.avg_pooler
                )
                hidden_states = torch.nn.functional.pad(
                    hidden_states, (0, 0, 0, pad_len), mode="constant", value=0.0
                )
                tgt_len += pad_len
            tgt_len = tgt_len // self.config.avg_pooler
            hidden_states = self.down_sample_layer(hidden_states.transpose(1, 2))
            output_length = (
                output_length // self.config.avg_pooler
                + (output_length % self.config.avg_pooler != 0).int()
            )
            hidden_states = hidden_states.transpose(1, 2)
            attention_mask, unpacking_index = get_sequence_mask(
                hidden_states, output_length
            )
            hidden_states = torch.masked_select(hidden_states, attention_mask).view(
                torch.sum(output_length), self.config.d_model
            )
            hidden_states = self.down_sample_norm(hidden_states)
        return (
            hidden_states,
            output_length,
            attention_mask,
            unpacking_index,
            tgt_len,
            bsz,
        )

    def get_output_length(self, mel_len):
        tgt_len = mel_len + 3 - self.config.kernel_size
        return (tgt_len + 2 - self.config.kernel_size) // self.config.stride_size + 1

    def encode(
        self,
        input_features,
        input_lens=None,
        output_length=None,
        return_codes_only=False,
        n_q=None,
        use_quantizer=True,
    ):
        if output_length is None:
            output_length = self.get_output_length(input_lens)
        input_features = unpack_hidden_states(input_features, input_lens)
        hidden_states, output_length, attention_mask, unpacking_index, tgt_len, bsz = (
            self.get_features(
                input_features=input_features.transpose(1, 2),
                output_length=output_length,
            )
        )
        dtype = hidden_states.dtype
        if use_quantizer and self.quantizer is not None:
            self.quantizer.float()
            codes = self.quantizer.encode(hidden_states.float(), n_q=n_q)
            if return_codes_only:
                return codes, output_length
            hidden_states = self.quantizer.decode(codes)
            hidden_states = hidden_states.to(dtype)
        else:
            codes = None
        hidden_states_packed = hidden_states.clone()
        # unpacking
        hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
            bsz, tgt_len, self.config.d_model
        )
        hidden_states = torch.where(attention_mask, hidden_states, 0)
        return hidden_states, hidden_states_packed, output_length, codes

    def decode_vq(self, codes):
        self.quantizer.float()
        hidden_states = self.quantizer.decode(codes)
        return hidden_states
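
# Usage sketch (illustrative only; requires CUDA flash-attn): the encoder consumes
# packed mel features of shape (sum(mel_lens), n_mels), downsamples them by conv2's
# stride and the avg_pooler factor, and optionally quantizes each downsampled frame with
# the residual vector quantizer. With a loaded config `cfg` (assumed):
#
#   encoder = AudioEncoder(cfg)
#   mel_lens = torch.tensor([400])
#   mels = torch.randn(400, cfg.n_mels)                       # packed mel frames
#   codes, out_len = encoder.encode(mels, input_lens=mel_lens, return_codes_only=True)
#   # codes holds one codebook index per quantizer per downsampled frame;
#   # decode_vq(codes) returns the matching continuous hidden states in d_model.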

class CausalConvTranspose1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
        self.norm = nn.GroupNorm(1, out_channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, hidden_states, input_length, output_dim=None):
        kernel_size = self.conv.kernel_size[0]
        stride = self.conv.stride[0]
        bsz = input_length.shape[0]
        if output_dim is None:
            output_dim = hidden_states.dim()
        if hidden_states.dim() <= 2:  # unpack sequence to 3d
            sequence_mask, unpacking_index = get_sequence_mask(
                hidden_states, input_length
            )
            hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
                bsz, torch.max(input_length), self.in_channels
            )
            hidden_states = torch.where(sequence_mask, hidden_states, 0)
        hidden_states = hidden_states.transpose(2, 1)  # (N, L, C) -> (N, C, L)
        hidden_states = self.conv(hidden_states)
        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states.transpose(2, 1)  # (N, C, L) -> (N, L, C)
        causal_padding_right = max(0, kernel_size - stride)
        hidden_states = hidden_states[
            :, : hidden_states.shape[1] - causal_padding_right, :
        ]
        output_length = (input_length - 1) * stride + kernel_size - causal_padding_right
        sequence_mask, _ = get_sequence_mask(hidden_states, output_length)
        if output_dim <= 2:
            hidden_states = torch.masked_select(hidden_states, sequence_mask).view(
                -1, self.out_channels
            )
        else:
            hidden_states = torch.where(sequence_mask, hidden_states, 0)
            hidden_states = hidden_states[:, : torch.max(output_length), :]
        return hidden_states, output_length
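
# Usage sketch (illustrative only): the transposed convolution upsamples by `stride` and
# the right side is trimmed by max(0, kernel_size - stride), so no output frame depends
# on future input frames; whenever kernel_size >= stride this gives
# output_length = (input_length - 1) * stride + stride = input_length * stride.
# Example (values assumed):
#
#   up = CausalConvTranspose1d(in_channels=512, out_channels=512, kernel_size=4, stride=4)
#   x = torch.randn(7, 512)                        # packed, lengths [4, 3]
#   y, out_len = up(x, torch.tensor([4, 3]), output_dim=3)
#   # y: (2, 16, 512) padded batch, out_len: tensor([16, 12])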

class AudioDecoder(nn.Module):
    def __init__(self, config: MiMoAudioTokenizerConfig):
        super().__init__()
        self.config = config
        self.max_source_positions = (
            self.config.max_audio_seconds
            * self.config.sampling_rate
            // self.config.hop_length
        )
        if self.config.avg_pooler != 1:
            self.dconv1 = CausalConvTranspose1d(
                self.config.d_model,
                self.config.d_model,
                self.config.avg_pooler,
                self.config.avg_pooler,
            )
        else:
            self.dconv1 = None
        self.position_embedding = RotaryEmbedding(
            config.rope_theta,
            config.d_model // config.decoder_attention_heads,
            self.max_source_positions,
            config.rope_type,
        )
        self.layers = nn.ModuleList(
            [
                TransformerLayer(
                    ACT2FN[self.config.activation_function],
                    self.config.d_model,
                    self.config.decoder_attention_heads,
                    self.config.decoder_ffn_dim,
                    causal=self.config.decoder_causal,
                    ln_type=self.config.ln_type,
                    attn_window_size=self.config.decoder_attn_window_size,
                )
                for _ in range(self.config.decoder_layers)
            ]
        )
        self.layer_norm = LAYER_NORM[config.ln_type](self.config.d_model)
        self.dconv2 = CausalConvTranspose1d(
            self.config.d_model,
            self.config.n_mels,
            self.config.decoder_kernel_size,
            self.config.decoder_stride_size,
        )
        self.vocoder = TransformerVocos(config)

    def forward(
        self,
        audio_embed,
        input_length,
    ):
        assert audio_embed.shape[-1] == self.config.d_model
        audio_embed = audio_embed.to(self.layer_norm.weight)
        if self.dconv1 is not None:
            audio_embed, output_length = self.dconv1(
                audio_embed, input_length, output_dim=3
            )
            _, tgt_len, _ = audio_embed.size()
        else:
            output_length = input_length
            tgt_len = audio_embed.size(0)
        hidden_states = audio_embed
        position_ids = (
            get_position_ids(output_length).long().to(hidden_states.device)
        )
        rope_position_embeddings = self.position_embedding(
            hidden_states, position_ids
        )
        # packing hidden states
        attention_mask, _ = get_sequence_mask(hidden_states, output_length)
        hidden_states = torch.masked_select(hidden_states, attention_mask).view(
            torch.sum(output_length), self.config.d_model
        )
        for idx, encoder_layer in enumerate(self.layers):
            hidden_states = encoder_layer(
                hidden_states,
                output_length,
                rope_position_embeddings=rope_position_embeddings,
            )
        hidden_states = self.layer_norm(hidden_states)
        coarse_mel, output_length = self.dconv2(
            hidden_states, output_length, output_dim=3
        )
        recon_wav, wav_length = self.vocoder(
            x=coarse_mel.transpose(1, 2),
            input_length=output_length,
        )
        return recon_wav

class MiMoAudioTokenizer(PreTrainedModel):
    config_class = MiMoAudioTokenizerConfig

    def __init__(self, config: MiMoAudioTokenizerConfig):
        super().__init__(config)
        self.config = config
        self.sampling_rate = config.sampling_rate
        self.encoder = AudioEncoder(config=config)
        self.decoder = AudioDecoder(config=config)
        self.downsample_rate = int(self.config.hop_length * 2 * self.config.avg_pooler)

    def get_output_length(self, mel_len):
        tgt_len = mel_len + 3 - self.config.kernel_size
        return (tgt_len + 2 - self.config.kernel_size) // self.config.stride_size + 1

    def encode(self, mels, input_lens, use_quantizer=True):
        input_features = mels
        encoder_output_length = self.get_output_length(input_lens)
        hidden_states, hidden_states_packed, encoder_output_length, codes = (
            self.encoder.encode(
                input_features, input_lens=input_lens, use_quantizer=use_quantizer
            )
        )
        return hidden_states, hidden_states_packed, encoder_output_length, codes

    def decode(self, codes):
        hidden_states = self.encoder.decode_vq(codes)
        output = self.decoder(
            hidden_states,
            torch.tensor([hidden_states.size(0)], device=hidden_states.device),
        )
        return output
    def streaming_decode(
        self,
        codes_chunks,
        chunk_input_lengths,
        history_cache=None,
        streaming_config=None,
        last_chunk=False,
    ):
        # Mutable defaults would be shared (and mutated) across calls, so the cache and
        # config are created per call when not supplied.
        if history_cache is None:
            history_cache = StreamingCache()
        if streaming_config is None:
            streaming_config = StreamingConfig()
        hidden_states = self.encoder.decode_vq(codes_chunks)
        input_lengths = []
        input_hidden_states = []
        start_idx = 0
        cache_hidden_states = []
        for i, input_length in enumerate(chunk_input_lengths):
            sample_hidden_states = hidden_states[start_idx:start_idx + input_length]
            start_idx += input_length
            if history_cache.hidden_states is not None:
                sample_hidden_states = torch.cat(
                    [history_cache.hidden_states[i], sample_hidden_states], dim=0
                )
                input_length += history_cache.hidden_states[i].size(0)
            input_hidden_states.append(sample_hidden_states)
            cache_hidden_states.append(sample_hidden_states.clone())
            input_lengths.append(input_length)
        input_hidden_states = torch.cat(input_hidden_states, dim=0)
        input_lengths = torch.tensor(input_lengths, device=hidden_states.device)
        output = self.decoder(input_hidden_states, input_lengths)
        return_wavs = []
        frames_per_token = (
            self.config.avg_pooler * self.config.stride_size * self.config.hop_length
        )
        processed_lengths = []
        for i, wav in enumerate(output):
            wav = wav.float().detach().cpu()
            start_idx = (
                history_cache.processed_lengths[i]
                if history_cache.processed_lengths is not None
                else 0
            )
            if last_chunk:
                return_wavs.append(wav[:, start_idx * frames_per_token:])
                new_processed_length = input_lengths[i].item()
            elif input_lengths[i].item() <= streaming_config.right_overlap:
                return_wavs.append(None)
                new_processed_length = 0
            else:
                end_idx = input_lengths[i].item() - streaming_config.right_overlap
                wav = wav[:, start_idx * frames_per_token: end_idx * frames_per_token]
                return_wavs.append(wav)
                new_processed_length = end_idx
            if input_lengths[i].item() > streaming_config.left_overlap:
                cache_hidden_states[i] = cache_hidden_states[i][-streaming_config.left_overlap:]
                new_processed_length -= (
                    input_lengths[i].item() - streaming_config.left_overlap
                )
            processed_lengths.append(new_processed_length)
        history_cache.hidden_states = cache_hidden_states
        history_cache.processed_lengths = processed_lengths
        return return_wavs, history_cache
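
# End-to-end usage sketch (illustrative only; the method names come from this file, but
# the checkpoint path, the feature-extraction step, and the `codes_chunk` / `chunk_len`
# placeholders are assumptions):
#
#   tokenizer = MiMoAudioTokenizer.from_pretrained("path/to/checkpoint").cuda().eval()
#   mels, mel_lens = ...  # packed log-mel features: (sum(mel_lens), n_mels) and (bsz,)
#   with torch.no_grad():
#       hidden, hidden_packed, enc_len, codes = tokenizer.encode(
#           mels.cuda(), mel_lens.cuda()
#       )
#       wav = tokenizer.decode(codes)   # (1, 1, num_samples) reconstruction
#
#   # Streaming decode: feed chunks of codes and carry the cache between calls.
#   cache = StreamingCache()
#   wavs, cache = tokenizer.streaming_decode(
#       codes_chunk, [chunk_len], history_cache=cache
#   )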