# Copyright 2025 Xiaomi Corporation.
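"""MiMo-Audio tokenizer model: a mel-spectrogram Transformer encoder with a
residual vector quantizer, plus a Transformer decoder and ISTFT-based vocoder
that reconstruct waveforms from the discrete codes."""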
import math
import torch
import torch.nn as nn
from flash_attn import flash_attn_varlen_func
from torch.nn import functional as F
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from .configuration_audio_tokenizer import MiMoAudioTokenizerConfig
from .modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update, apply_rotary_pos_emb
from .quantization import ResidualVectorQuantizer
from dataclasses import dataclass, field
from typing import List, Optional
def get_sequence_mask(inputs, inputs_length):
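    """Return a (bsz, max_len, 1) boolean padding mask for `inputs_length`,
    plus flat indices that map packed (total_tokens, ...) rows back to a
    padded (bsz, max_len, ...) layout."""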
if inputs.dim() == 3:
bsz, tgt_len, _ = inputs.size()
else:
bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(
bsz, tgt_len, 1
)
unpacking_index = torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1
return sequence_mask, unpacking_index
def unpack_hidden_states(
hidden_states, lengths, sequence_mask=None, unpacking_index=None
):
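    """Scatter packed (total_tokens, dim) hidden states back into a padded
    (bsz, max_len, dim) tensor, zeroing the padding positions."""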
bsz = lengths.shape[0]
if sequence_mask is None or unpacking_index is None:
sequence_mask, unpacking_index = get_sequence_mask(hidden_states, lengths)
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
bsz, torch.max(lengths), hidden_states.shape[-1]
)
hidden_states = torch.where(sequence_mask, hidden_states, 0)
return hidden_states
def get_position_ids(lengths):
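    """Positions 0..len-1 for each sample, concatenated along the packed axis."""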
total_len = lengths.sum()
offset = torch.cat([torch.zeros(1).to(lengths), lengths[:-1].cumsum(dim=0)])
offset = torch.repeat_interleave(offset, lengths)
position_ids = torch.arange(0, total_len).to(offset) - offset
return position_ids
@dataclass
class StreamingConfig:
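    """Chunking hyperparameters for streaming decoding, expressed in quantizer
    tokens (the defaults appear to assume a 25 Hz token rate, e.g.
    seg_point = 60 s * 25 tokens/s)."""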
seg_point: int = field(default=60 * 25)
process_seg_point: bool = field(default=True)
left_overlap: int = field(default=10 * 25)
right_overlap: int = field(default=40)
seg_point_left_overlap: int = field(default=0)
@dataclass
class StreamingCache:
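    """Per-sample decoder context carried between streaming_decode calls."""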
    hidden_states: Optional[List[torch.Tensor]] = field(default=None)
    processed_lengths: Optional[List[int]] = field(default=None)
class ISTFT(nn.Module):
"""
Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
See issue: https://github.com/pytorch/pytorch/issues/62323
Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
The NOLA constraint is met as we trim padded samples anyway.
Args:
n_fft (int): Size of Fourier transform.
hop_length (int): The distance between neighboring sliding window frames.
win_length (int): The size of window frame and STFT filter.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(
self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
window = torch.hann_window(win_length)
self.register_buffer("window", window)
def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
Args:
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
N is the number of frequency bins, and T is the number of time frames.
Returns:
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
"""
if self.padding == "center":
# Fallback to pytorch native implementation
return torch.istft(
spec,
self.n_fft,
self.hop_length,
self.win_length,
self.window,
center=True,
)
elif self.padding == "same":
pad = (self.win_length - self.hop_length) // 2
else:
raise ValueError("Padding must be 'center' or 'same'.")
assert spec.dim() == 3, "Expected a 3D tensor as input"
B, N, T = spec.shape
# Inverse FFT
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
ifft = ifft * self.window[None, :, None]
# Overlap and Add
output_size = (T - 1) * self.hop_length + self.win_length
y = torch.nn.functional.fold(
ifft,
output_size=(1, output_size),
kernel_size=(1, self.win_length),
stride=(1, self.hop_length),
)[:, 0, 0, pad:-pad]
# Window envelope
window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
window_envelope = torch.nn.functional.fold(
window_sq,
output_size=(1, output_size),
kernel_size=(1, self.win_length),
stride=(1, self.hop_length),
).squeeze()[pad:-pad]
# Normalize
assert (window_envelope > 1e-11).all()
y = y / window_envelope
return y
class ISTFTHead(nn.Module):
"""
ISTFT Head module for predicting STFT complex coefficients.
Args:
dim (int): Hidden dimension of the model.
n_fft (int): Size of Fourier transform.
hop_length (int): The distance between neighboring sliding window frames, which should align with
the resolution of the input features.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
super().__init__()
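        # magnitude and phase each take n_fft // 2 + 1 frequency bins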
out_dim = n_fft + 2
self.out = torch.nn.Linear(dim, out_dim)
self.istft = ISTFT(
n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the ISTFTHead module.
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
x = self.out(x).transpose(1, 2)
mag, p = x.chunk(2, dim=1)
mag = torch.exp(mag)
mag = torch.clip(
mag, max=1e2
) # safeguard to prevent excessively large magnitudes
        # Phase wrapping happens here: cos(p) and sin(p) are the real and
        # imaginary parts of exp(1j * p), so the complex spectrogram can be
        # built from them directly. Recomputing the phase via atan2 and
        # exp(1j * phase) would produce the same result at extra cost.
        x = torch.cos(p)
        y = torch.sin(p)
original_dtype = x.dtype
S = mag.float() * (x.float() + 1j * y.float())
audio = self.istft(S)
audio = audio.to(original_dtype)
return audio
class RotaryEmbedding(nn.Module):
def __init__(self, base, dim, max_seq_len, rope_type="default", device=None):
super().__init__()
self.max_seq_len = max_seq_len
self.rope_type = rope_type
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(
device=device, base=base, dim=dim
)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update
def forward(self, x, position_ids):
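        # position_ids is a flat (total_tokens,) vector for packed sequences;
        # cos/sin come back with shape (total_tokens, head_dim).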
inv_freq_expanded = self.inv_freq[:, None].float().expand(-1, 1).to(x.device)
position_ids_expanded = position_ids[None, :].float()
device_type = (
x.device.type
if isinstance(x.device.type, str) and x.device.type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (
inv_freq_expanded.float() @ position_ids_expanded.float()
).transpose(0, 1)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
LAYER_NORM = {"LayerNorm": nn.LayerNorm, "RMSNorm": RMSNorm}
class Attention(nn.Module):
def __init__(self, embed_dim, num_heads, window_size=(-1, -1), causal=False):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.window_size = window_size
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.causal = causal
def forward(
self,
hidden_states: torch.Tensor,
seq_len: torch.Tensor,
rope_position_embeddings=None,
):
        # hidden_states arrives packed as (total_tokens, embed_dim)
        total_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states).view(
            total_len, self.num_heads, self.head_dim
        )
        key_states = self.k_proj(hidden_states).view(
            total_len, self.num_heads, self.head_dim
        )
        value_states = self.v_proj(hidden_states).view(
            total_len, self.num_heads, self.head_dim
        )
if rope_position_embeddings is not None:
cos, sin = rope_position_embeddings
query_states = apply_rotary_pos_emb(query_states, cos, sin)
key_states = apply_rotary_pos_emb(key_states, cos, sin)
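        # flash-attn's varlen kernel takes int32 cumulative sequence lengths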
cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(
torch.int32
)
max_seqlen = torch.max(seq_len).to(torch.int32).detach()
attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_len,
cu_len,
max_seqlen,
max_seqlen,
causal=self.causal,
window_size=self.window_size,
)
        attn_output = attn_output.reshape(total_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output
class TransformerLayer(nn.Module):
def __init__(
self,
act,
d_model,
encoder_attention_heads,
encoder_ffn_dim,
causal,
ln_type="LayerNorm",
attn_window_size=(-1, -1),
):
super().__init__()
self.embed_dim = d_model
self.self_attn = Attention(
self.embed_dim, encoder_attention_heads, attn_window_size, causal
)
self.self_attn_layer_norm = LAYER_NORM[ln_type](self.embed_dim)
self.activation_fn = act
self.fc1 = nn.Linear(self.embed_dim, encoder_ffn_dim)
self.fc2 = nn.Linear(encoder_ffn_dim, self.embed_dim)
self.final_layer_norm = LAYER_NORM[ln_type](self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
seq_len: torch.Tensor,
rope_position_embeddings: torch.Tensor,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(
hidden_states, seq_len, rope_position_embeddings=rope_position_embeddings
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = self.fc2(hidden_states)
hidden_states = residual + hidden_states
if (
hidden_states.dtype == torch.float16
or hidden_states.dtype == torch.bfloat16
) and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(
hidden_states, min=-clamp_value, max=clamp_value
)
return hidden_states
class TransformerVocos(nn.Module):
def __init__(self, config: MiMoAudioTokenizerConfig):
super().__init__()
self.config = config
self.max_source_positions = (
self.config.max_audio_seconds
* self.config.sampling_rate
// self.config.hop_length
)
self.embeddings = nn.Linear(config.n_mels, config.vocoder_dim, bias=False)
        self.position_embedding = RotaryEmbedding(
config.rope_theta,
config.vocoder_dim // config.vocoder_attention_heads,
self.max_source_positions,
self.config.rope_type,
)
self.layers = nn.ModuleList(
[
TransformerLayer(
ACT2FN[self.config.activation_function],
self.config.vocoder_dim,
self.config.vocoder_attention_heads,
self.config.vocoder_intermediate_dim,
causal=False,
ln_type=self.config.ln_type,
attn_window_size=self.config.vocoder_attn_window_size,
)
for _ in range(self.config.vocoder_num_layers)
]
)
self.layer_norm = LAYER_NORM[self.config.ln_type](self.config.vocoder_dim)
self.hop_size = self.config.hop_length
self.head = ISTFTHead(
self.config.vocoder_dim,
self.config.nfft,
self.config.hop_length,
self.config.vocoder_padding,
)
def forward(self, x: torch.Tensor, input_length):
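        # x: (bsz, n_mels, T) coarse mel features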
x = x.transpose(1, 2)
attention_mask, unpacking_index = get_sequence_mask(x, input_length)
x = torch.masked_select(x, attention_mask).view(
torch.sum(input_length), self.config.n_mels
)
x = self.embeddings(x)
position_ids = torch.arange(0, x.size(0), device=x.device, dtype=torch.long)
        rope_position_embeddings = self.position_embedding(x, position_ids)
for idx, layer in enumerate(self.layers):
x = layer(
x, input_length, rope_position_embeddings=rope_position_embeddings
)
x = self.layer_norm(x)
x = unpack_hidden_states(x, input_length, attention_mask, unpacking_index)
x = self.head(x)
output_length = input_length * self.hop_size
return x[:, None, :], output_length
class AudioEncoder(nn.Module):
def __init__(self, config: MiMoAudioTokenizerConfig):
super().__init__()
config._attn_implementation = "flash_attention_2"
self.config = config
self.max_source_positions = (
config.max_audio_seconds * config.sampling_rate // config.hop_length
) // config.stride_size
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.skip_layer_idx = config.encoder_skip_layer_id
self.conv1 = nn.Conv1d(
config.n_mels, config.d_model, kernel_size=config.kernel_size, padding=1
)
self.conv2 = nn.Conv1d(
config.d_model,
config.d_model,
kernel_size=config.kernel_size,
stride=config.stride_size,
padding=1,
)
self.position_embedding = RotaryEmbedding(
config.rope_theta,
config.d_model // config.encoder_attention_heads,
self.max_source_positions,
config.rope_type,
)
self.layers = nn.ModuleList(
[
TransformerLayer(
ACT2FN[config.activation_function],
config.d_model,
config.encoder_attention_heads,
config.encoder_ffn_dim,
causal=self.config.encoder_causal,
ln_type=self.config.ln_type,
attn_window_size=self.config.encoder_attn_window_size,
)
for _ in range(config.encoder_layers)
]
)
self.layer_norm = LAYER_NORM[config.ln_type](config.d_model)
if self.config.avg_pooler != 1:
self.down_sample_layer = nn.Sequential(
nn.Conv1d(
config.d_model,
config.d_model,
config.avg_pooler,
config.avg_pooler,
bias=False,
),
nn.GELU(),
)
self.down_sample_norm = LAYER_NORM[config.ln_type](config.d_model)
else:
self.down_sample_layer = None
if self.config.num_quantizers != 0:
self.quantizer = ResidualVectorQuantizer(
dimension=self.config.d_model,
n_q=self.config.num_quantizers,
bins=self.config.codebook_size,
threshold_ema_dead_code=self.config.threshold_ema_dead_code,
)
else:
self.quantizer = None
def get_features(self, input_features, output_length):
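        """Mel features -> packed (total_tokens, d_model) encoder states,
        downsampled by `avg_pooler` when configured."""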
input_features = input_features.to(self.conv1.weight)
inputs_embeds = nn.functional.gelu(self.conv1(input_features))
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
inputs_embeds = inputs_embeds.permute(0, 2, 1)
bsz, tgt_len, _ = inputs_embeds.size()
hidden_states = inputs_embeds
position_ids = (
get_position_ids(output_length).long().to(input_features.device)
)
rope_position_embeddings = self.position_embedding(
input_features, position_ids
)
attention_mask, unpacking_index = get_sequence_mask(
hidden_states, output_length
)
hidden_states = torch.masked_select(hidden_states, attention_mask).view(
torch.sum(output_length), self.config.d_model
)
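        # optional skip connection: stash one intermediate layer's output and
        # add it back after the final layer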
skip_connect_hidden_states = 0.0
for idx, encoder_layer in enumerate(self.layers):
hidden_states = encoder_layer(
hidden_states,
output_length,
rope_position_embeddings=rope_position_embeddings,
)
if (self.skip_layer_idx is not None) and idx == self.skip_layer_idx - 1:
skip_connect_hidden_states = hidden_states.clone()
hidden_states += skip_connect_hidden_states
hidden_states = self.layer_norm(hidden_states)
if self.down_sample_layer is not None:
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
bsz, tgt_len, self.config.d_model
)
if hidden_states.size(1) % self.config.avg_pooler:
pad_len = (
self.config.avg_pooler
- hidden_states.size(1) % self.config.avg_pooler
)
hidden_states = torch.nn.functional.pad(
hidden_states, (0, 0, 0, pad_len), mode="constant", value=0.0
)
tgt_len += pad_len
tgt_len = tgt_len // self.config.avg_pooler
hidden_states = self.down_sample_layer(hidden_states.transpose(1, 2))
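            # ceil(output_length / avg_pooler)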
output_length = (
output_length // self.config.avg_pooler
+ (output_length % self.config.avg_pooler != 0).int()
)
hidden_states = hidden_states.transpose(1, 2)
attention_mask, unpacking_index = get_sequence_mask(
hidden_states, output_length
)
hidden_states = torch.masked_select(hidden_states, attention_mask).view(
torch.sum(output_length), self.config.d_model
)
hidden_states = self.down_sample_norm(hidden_states)
return (
hidden_states,
output_length,
attention_mask,
unpacking_index,
tgt_len,
bsz,
)
def get_output_length(self, mel_len):
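        # conv1 (stride 1, padding 1): L -> L + 3 - kernel_size
        # conv2 (stride s, padding 1): L' -> (L' + 2 - kernel_size) // s + 1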
tgt_len = mel_len + 3 - self.config.kernel_size
return (tgt_len + 2 - self.config.kernel_size) // self.config.stride_size + 1
@torch.no_grad()
def encode(
self,
input_features,
input_lens=None,
output_length=None,
return_codes_only=False,
n_q=None,
use_quantizer=True,
):
if output_length is None:
output_length = self.get_output_length(input_lens)
input_features = unpack_hidden_states(input_features, input_lens)
hidden_states, output_length, attention_mask, unpacking_index, tgt_len, bsz = (
self.get_features(
input_features=input_features.transpose(1, 2),
output_length=output_length,
)
)
dtype = hidden_states.dtype
if use_quantizer and self.quantizer is not None:
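            # run the RVQ in float32 and cast the decoded embeddings back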
self.quantizer.float()
codes = self.quantizer.encode(hidden_states.float(), n_q=n_q)
if return_codes_only:
return codes, output_length
hidden_states = self.quantizer.decode(codes)
hidden_states = hidden_states.to(dtype)
else:
codes = None
hidden_states_packed = hidden_states.clone()
# unpacking
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
bsz, tgt_len, self.config.d_model
)
hidden_states = torch.where(attention_mask, hidden_states, 0)
return hidden_states, hidden_states_packed, output_length, codes
@torch.no_grad()
def decode_vq(self, codes):
self.quantizer.float()
hidden_states = self.quantizer.decode(codes)
return hidden_states
class CausalConvTranspose1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride):
super().__init__()
self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
self.norm = nn.GroupNorm(1, out_channels)
self.in_channels = in_channels
self.out_channels = out_channels
def forward(self, hidden_states, input_length, output_dim=None):
kernel_size = self.conv.kernel_size[0]
stride = self.conv.stride[0]
bsz = input_length.shape[0]
if output_dim is None:
output_dim = hidden_states.dim()
if hidden_states.dim() <= 2: # unpack sequence to 3d
sequence_mask, unpacking_index = get_sequence_mask(
hidden_states, input_length
)
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
bsz, torch.max(input_length), self.in_channels
)
hidden_states = torch.where(sequence_mask, hidden_states, 0)
hidden_states = hidden_states.transpose(2, 1) # (N, L, C) -> (N, C, L)
hidden_states = self.conv(hidden_states)
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states.transpose(2, 1) # (N, C, L) -> (N, L, C)
        # trim trailing frames so the transposed convolution remains causal
        causal_padding_right = max(0, kernel_size - stride)
        hidden_states = hidden_states[
            :, : hidden_states.shape[1] - causal_padding_right, :
        ]
        output_length = (input_length - 1) * stride + kernel_size - causal_padding_right
sequence_mask, _ = get_sequence_mask(hidden_states, output_length)
if output_dim <= 2:
hidden_states = torch.masked_select(hidden_states, sequence_mask).view(
-1, self.out_channels
)
else:
hidden_states = torch.where(sequence_mask, hidden_states, 0)
hidden_states = hidden_states[:, : torch.max(output_length), :]
return hidden_states, output_length
class AudioDecoder(nn.Module):
def __init__(self, config: MiMoAudioTokenizerConfig):
super().__init__()
self.config = config
self.max_source_positions = (
self.config.max_audio_seconds
* self.config.sampling_rate
// self.config.hop_length
)
if self.config.avg_pooler != 1:
self.dconv1 = CausalConvTranspose1d(
self.config.d_model,
self.config.d_model,
self.config.avg_pooler,
self.config.avg_pooler,
)
else:
self.dconv1 = None
self.position_embedding = RotaryEmbedding(
config.rope_theta,
config.d_model // config.decoder_attention_heads,
self.max_source_positions,
config.rope_type,
)
self.layers = nn.ModuleList(
[
TransformerLayer(
ACT2FN[self.config.activation_function],
self.config.d_model,
self.config.decoder_attention_heads,
self.config.decoder_ffn_dim,
causal=self.config.decoder_causal,
ln_type=self.config.ln_type,
attn_window_size=self.config.decoder_attn_window_size,
)
for _ in range(self.config.decoder_layers)
]
)
self.layer_norm = LAYER_NORM[config.ln_type](self.config.d_model)
self.dconv2 = CausalConvTranspose1d(
self.config.d_model,
self.config.n_mels,
self.config.decoder_kernel_size,
self.config.decoder_stride_size,
)
self.vocoder = TransformerVocos(config)
def forward(
self,
audio_embed,
input_length,
):
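        """Packed (total_tokens, d_model) embeddings -> waveform (bsz, 1, T)."""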
assert audio_embed.shape[-1] == self.config.d_model
audio_embed = audio_embed.to(self.layer_norm.weight)
if self.dconv1 is not None:
audio_embed, output_length = self.dconv1(
audio_embed, input_length, output_dim=3
)
_, tgt_len, _ = audio_embed.size()
else:
output_length = input_length
tgt_len = audio_embed.size(0)
hidden_states = audio_embed
position_ids = (
get_position_ids(output_length).long().to(hidden_states.device)
)
rope_position_embeddings = self.position_embedding(
hidden_states, position_ids
)
# packing hidden states
attention_mask, _ = get_sequence_mask(hidden_states, output_length)
hidden_states = torch.masked_select(hidden_states, attention_mask).view(
torch.sum(output_length), self.config.d_model
)
for idx, encoder_layer in enumerate(self.layers):
hidden_states = encoder_layer(
hidden_states,
output_length,
rope_position_embeddings=rope_position_embeddings,
)
hidden_states = self.layer_norm(hidden_states)
coarse_mel, output_length = self.dconv2(
hidden_states, output_length, output_dim=3
)
recon_wav, wav_length = self.vocoder(
x=coarse_mel.transpose(1, 2),
input_length=output_length,
)
return recon_wav
class MiMoAudioTokenizer(PreTrainedModel):
config_class = MiMoAudioTokenizerConfig
def __init__(self, config: MiMoAudioTokenizerConfig):
super().__init__(config)
self.config = config
self.sampling_rate = config.sampling_rate
self.encoder = AudioEncoder(config=config)
self.decoder = AudioDecoder(config=config)
self.downsample_rate = int(self.config.hop_length * 2 * self.config.avg_pooler)
def get_output_length(self, mel_len):
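        # same frame-count arithmetic as AudioEncoder.get_output_length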
tgt_len = mel_len + 3 - self.config.kernel_size
return (tgt_len + 2 - self.config.kernel_size) // self.config.stride_size + 1
@torch.no_grad()
def encode(self, mels, input_lens, use_quantizer=True):
input_features = mels
hidden_states, hidden_states_packed, encoder_output_length, codes = (
self.encoder.encode(
input_features, input_lens=input_lens, use_quantizer=use_quantizer
)
)
return hidden_states, hidden_states_packed, encoder_output_length, codes
@torch.no_grad()
def decode(self, codes):
hidden_states = self.encoder.decode_vq(codes)
output = self.decoder(
hidden_states,
torch.tensor([hidden_states.size(0)], device=hidden_states.device),
)
return output
@torch.no_grad()
    def streaming_decode(
        self,
        codes_chunks,
        chunk_input_lengths,
        history_cache=None,
        streaming_config=None,
        last_chunk=False,
    ):
        # Create fresh defaults per call: a shared mutable default
        # StreamingCache would leak state across unrelated calls.
        if history_cache is None:
            history_cache = StreamingCache()
        if streaming_config is None:
            streaming_config = StreamingConfig()
hidden_states = self.encoder.decode_vq(codes_chunks)
input_lengths = []
input_hidden_states = []
start_idx = 0
cache_hidden_states = []
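        # prepend each sample's cached left-overlap context so the decoder has
        # history at the chunk boundary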
for i, input_length in enumerate(chunk_input_lengths):
sample_hidden_states = hidden_states[start_idx:start_idx + input_length]
start_idx += input_length
            if history_cache.hidden_states is not None:
                sample_hidden_states = torch.cat(
                    [history_cache.hidden_states[i], sample_hidden_states], dim=0
                )
                input_length += history_cache.hidden_states[i].size(0)
input_hidden_states.append(sample_hidden_states)
cache_hidden_states.append(sample_hidden_states.clone())
input_lengths.append(input_length)
input_hidden_states = torch.cat(input_hidden_states, dim=0)
input_lengths = torch.tensor(input_lengths, device=hidden_states.device)
output = self.decoder(input_hidden_states, input_lengths)
return_wavs = []
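        # waveform samples produced per quantizer token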
frames_per_token = self.config.avg_pooler * self.config.stride_size * self.config.hop_length
processed_lengths = []
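        # Per sample: skip audio already emitted, hold back `right_overlap`
        # tokens (except on the last chunk), and keep the trailing
        # `left_overlap` tokens as context for the next call.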
for i, wav in enumerate(output):
wav = wav.float().detach().cpu()
            start_idx = (
                history_cache.processed_lengths[i]
                if history_cache.processed_lengths is not None
                else 0
            )
if last_chunk:
return_wavs.append(wav[:, start_idx * frames_per_token:])
new_processed_length = input_lengths[i].item()
elif input_lengths[i].item() <= streaming_config.right_overlap:
return_wavs.append(None)
new_processed_length = 0
else:
                end_idx = input_lengths[i].item() - streaming_config.right_overlap
                wav = wav[:, start_idx * frames_per_token : end_idx * frames_per_token]
return_wavs.append(wav)
new_processed_length = end_idx
            if input_lengths[i].item() > streaming_config.left_overlap:
                cache_hidden_states[i] = cache_hidden_states[i][
                    -streaming_config.left_overlap :
                ]
                new_processed_length -= (
                    input_lengths[i].item() - streaming_config.left_overlap
                )
processed_lengths.append(new_processed_length)
history_cache.hidden_states = cache_hidden_states
history_cache.processed_lengths = processed_lengths
return return_wavs, history_cache
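

# Minimal usage sketch (hypothetical paths and shapes): `mels` is a packed
# (total_frames, n_mels) mel batch and `input_lens` holds per-sample frame
# counts.
#
#   tokenizer = MiMoAudioTokenizer.from_pretrained("<model_dir>").cuda().eval()
#   _, _, token_lens, codes = tokenizer.encode(mels, input_lens)
#   wav = tokenizer.decode(codes)  # (bsz, 1, num_samples)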