# Vi-SparkTTS-0.5B / modeling_spark_tts.py
# coding=utf-8
# Copyright 2024 The SparkAudio Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch SparkTTS model."""
import torch
import torch.nn as nn
import numpy as np
import os
import warnings
from pathlib import Path
from typing import Dict, Any, Tuple, Optional, Union
from transformers import PreTrainedModel, AutoModelForCausalLM, Wav2Vec2FeatureExtractor, Wav2Vec2Model
from transformers.utils import logging, requires_backends
from transformers.generation.utils import GenerationMixin
from transformers.configuration_utils import PretrainedConfig
from safetensors.torch import load_file
import torchaudio.transforms as TT # Directly use torchaudio
# NOTE: The BiCodec sub-module components (SpeakerEncoder, Encoder, Decoder,
# WaveGenerator, FactorizedVectorQuantize, and their building blocks) are
# defined inline further down in this single file instead of being imported
# from a separate _modeling_bicodec_components.py module.
""" Utility functions for SparkTTS """
import random
import soxr
import soundfile
import torch
import torchaudio
import numpy as np
from pathlib import Path
from typing import Tuple, Dict, Any
from numpy.lib.stride_tricks import sliding_window_view
from omegaconf import OmegaConf # Keep if BiCodec config loading needs it
# --- Token Maps (from sparktts/utils/token_parser.py) ---
TASK_TOKEN_MAP = {
"vc": "<|task_vc|>",
"tts": "<|task_tts|>",
"asr": "<|task_asr|>",
"s2s": "<|task_s2s|>",
"t2s": "<|task_t2s|>",
"understand": "<|task_understand|>",
"caption": "<|task_cap|>",
"controllable_tts": "<|task_controllable_tts|>",
"prompt_tts": "<|task_prompt_tts|>",
"speech_edit": "<|task_edit|>",
}
LEVELS_MAP = {
"very_low": 0,
"low": 1,
"moderate": 2,
"high": 3,
"very_high": 4,
}
LEVELS_MAP_UI = {
1: 'very_low',
2: 'low',
3: 'moderate',
4: 'high',
5: 'very_high'
}
GENDER_MAP = {
"female": 0,
"male": 1,
}
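# Example (illustrative, not part of the original file): these maps convert
# human-readable attribute choices into the discrete labels used when building
# control prompts; the prompt construction itself is handled elsewhere.
#
#   gender_id = GENDER_MAP["female"]            # 0
#   pitch_level = LEVELS_MAP[LEVELS_MAP_UI[3]]  # UI level 3 -> "moderate" -> 2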
# --- Audio Utils (from sparktts/utils/audio.py) ---
def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray:
temp = np.sort(np.abs(audio))
if len(temp) == 0: # Handle empty audio case
return audio
if temp[-1] < 0.1:
scaling_factor = max(temp[-1], 1e-3)
audio = audio / scaling_factor * 0.1
temp = temp[temp > 0.01]
L = temp.shape[0]
if L <= 10:
return audio
volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)])
if volume == 0: # Avoid division by zero if volume is effectively zero
return audio
audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10)
max_value = np.max(np.abs(audio)) if len(audio) > 0 else 0
if max_value > 1:
audio = audio / max_value
return audio
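# Example (illustrative sketch): normalizing a quiet mono clip in isolation;
# `wav` is a hypothetical float32 numpy array.
#
#   wav = (0.05 * np.random.randn(16000)).astype(np.float32)
#   wav = audio_volume_normalize(wav, coeff=0.2)
#   assert np.max(np.abs(wav)) <= 1.0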
def load_audio(
adfile: Path,
sampling_rate: int = None,
length: int = None,
volume_normalize: bool = False,
segment_duration: int = None,
) -> np.ndarray:
try:
audio, sr = soundfile.read(adfile, dtype='float32') # Ensure float32
except Exception as e:
raise IOError(f"Could not read audio file {adfile}: {e}")
if audio is None or len(audio) == 0:
raise ValueError(f"Audio file {adfile} is empty or invalid.")
if len(audio.shape) > 1:
audio = audio[:, 0]
if sampling_rate is not None and sr != sampling_rate:
try:
# Ensure input is float64 for soxr
audio = audio.astype(np.float64)
audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ")
# Convert back to float32
audio = audio.astype(np.float32)
sr = sampling_rate
except Exception as e:
raise RuntimeError(f"Failed to resample audio from {sr}Hz to {sampling_rate}Hz: {e}")
if segment_duration is not None:
seg_length = int(sr * segment_duration)
audio = random_select_audio_segment(audio, seg_length)
if volume_normalize:
audio = audio_volume_normalize(audio)
if length is not None:
if audio.shape[0] > length:
audio = audio[:length]
else:
audio = np.pad(audio, (0, int(length - audio.shape[0])), mode='constant')
return audio
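# Example (illustrative sketch): loading a 16 kHz, volume-normalized reference
# clip; "prompt.wav" is a hypothetical path.
#
#   wav = load_audio("prompt.wav", sampling_rate=16000, volume_normalize=True)
#   print(wav.shape)  # (num_samples,), mono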
def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray:
if audio.shape[0] < length:
audio = np.pad(audio, (0, int(length - audio.shape[0])), mode='constant')
start_index = 0 # If padded, start from beginning
elif audio.shape[0] == length:
start_index = 0 # If exact length, start from beginning
else:
start_index = random.randint(0, audio.shape[0] - length)
end_index = int(start_index + length)
return audio[start_index:end_index]
# --- File Utils (Minimal required) ---
def load_config_yaml(config_path: Path) -> Dict:
"""Loads a YAML configuration file using OmegaConf."""
# Check if path exists
if not Path(config_path).is_file():
raise FileNotFoundError(f"YAML Config file not found: {config_path}")
try:
config = OmegaConf.load(config_path)
# Convert OmegaConf DictConfig to standard Python dict
return OmegaConf.to_container(config, resolve=True)
except Exception as e:
raise IOError(f"Error loading YAML config file {config_path}: {e}")
""" PyTorch SparkTTS BiCodec sub-module definitions."""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import random
from torch.nn.utils import weight_norm, remove_weight_norm
from torch import Tensor, int32
from torch.amp import autocast
from typing import Any, Dict, List, Tuple, Optional
from collections import namedtuple
from functools import wraps, partial
from contextlib import nullcontext
from packaging import version
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
from einx import get_at # Ensure einx is installed: pip install einx
# ===============================================================
# Start: Content from sparktts/modules/blocks/layers.py
# ===============================================================
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
# Scripting this activation brings a ~1.4x model speedup
@torch.jit.script
def snake(x, alpha):
shape = x.shape
x = x.reshape(shape[0], shape[1], -1)
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
x = x.reshape(shape)
return x
class Snake1d(nn.Module):
def __init__(self, channels):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
def forward(self, x):
return snake(x, self.alpha)
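# Example (illustrative sketch): the snake activation is shape-preserving on
# (B, C, T) tensors, with a learnable per-channel alpha.
#
#   act = Snake1d(channels=64)
#   y = act(torch.randn(2, 64, 100))  # -> torch.Size([2, 64, 100])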
class ResidualUnit(nn.Module):
def __init__(self, dim: int = 16, dilation: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.block = nn.Sequential(
Snake1d(dim),
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
Snake1d(dim),
WNConv1d(dim, dim, kernel_size=1),
)
def forward(self, x):
y = self.block(x)
# Adjust padding handling if input and output shapes differ
diff = x.shape[-1] - y.shape[-1]
if diff > 0:
pad = diff // 2
x = x[..., pad:pad + y.shape[-1]] # Ensure shapes match for residual connection
elif diff < 0:
pad = -diff // 2
y = y[..., pad:pad + x.shape[-1]]
return x + y
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
# ===============================================================
# End: Content from sparktts/modules/blocks/layers.py
# ===============================================================
# ===============================================================
# Start: Content from sparktts/modules/blocks/samper.py
# ===============================================================
class SamplingBlock(nn.Module):
"""Sampling block for upsampling or downsampling"""
def __init__(
self,
dim: int,
groups: int = 1,
upsample_scale: int = 1,
downsample_scale: int = 1,
) -> None:
"""
Args:
dim: input dimension
groups: number of groups
upsample_scale: upsampling scale
downsample_scale: downsampling scale
"""
super(SamplingBlock, self).__init__()
self.upsample_scale = upsample_scale
self.downsample_scale = downsample_scale
if self.upsample_scale > 1:
self.de_conv_upsampler = nn.Sequential(
nn.LeakyReLU(0.2),
nn.ConvTranspose1d(
dim,
dim,
kernel_size=upsample_scale * 2,
stride=upsample_scale,
padding=upsample_scale // 2 + upsample_scale % 2,
output_padding=upsample_scale % 2,
groups=groups,
),
)
if self.downsample_scale > 1:
self.conv_downsampler = nn.Sequential(
nn.LeakyReLU(0.2),
nn.Conv1d(
dim,
dim,
kernel_size=2 * downsample_scale,
stride=downsample_scale,
padding=downsample_scale // 2 + downsample_scale % 2,
groups=groups,
),
)
@staticmethod
def repeat_upsampler(x, upsample_scale):
return x.repeat_interleave(upsample_scale, dim=2)
@staticmethod
def skip_downsampler(x, downsample_scale):
return F.avg_pool1d(x, kernel_size=downsample_scale, stride=downsample_scale)
def forward(self, x):
# Input expected as (B, D, T) from VocosBackbone output (B, T, D)
# x = x.transpose(1, 2) # Remove this transpose, input should be (B, D, T)
if self.upsample_scale > 1:
repeat_res = self.repeat_upsampler(x, self.upsample_scale)
deconv_res = self.de_conv_upsampler(x)
# Ensure shapes match for addition
if deconv_res.shape[-1] > repeat_res.shape[-1]:
deconv_res = deconv_res[..., :repeat_res.shape[-1]]
elif repeat_res.shape[-1] > deconv_res.shape[-1]:
repeat_res = repeat_res[..., :deconv_res.shape[-1]]
upmerge_res = repeat_res + deconv_res
else:
upmerge_res = x
repeat_res = x
if self.downsample_scale > 1:
conv_res = self.conv_downsampler(upmerge_res)
skip2_res = self.skip_downsampler(upmerge_res, self.downsample_scale)
skip1_res = self.skip_downsampler(repeat_res, self.downsample_scale)
# Ensure shapes match
min_len = min(conv_res.shape[-1], skip1_res.shape[-1], skip2_res.shape[-1])
conv_res = conv_res[..., :min_len]
skip1_res = skip1_res[..., :min_len]
skip2_res = skip2_res[..., :min_len]
else:
conv_res = upmerge_res
skip2_res = upmerge_res
skip1_res = repeat_res
final_res = conv_res + skip1_res + skip2_res
# Return (B, D, T) for next VocosBackbone
# return final_res.transpose(1, 2) # Remove this, keep (B, D, T)
return final_res
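# Example (illustrative sketch): upsampling a (B, D, T) feature map by 2x; the
# repeat-interleave and transposed-conv paths are summed.
#
#   block = SamplingBlock(dim=32, upsample_scale=2)
#   y = block(torch.randn(2, 32, 50))  # -> torch.Size([2, 32, 100])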
# ===============================================================
# End: Content from sparktts/modules/blocks/samper.py
# ===============================================================
# ===============================================================
# Start: Content from sparktts/modules/speaker/pooling_layers.py
# ===============================================================
class TAP(nn.Module):
"""
Temporal average pooling, only first-order mean is considered
"""
def __init__(self, in_dim=0, **kwargs):
super(TAP, self).__init__()
self.in_dim = in_dim
def forward(self, x):
pooling_mean = x.mean(dim=-1)
        # To be compatible with 2D input
pooling_mean = pooling_mean.flatten(start_dim=1)
return pooling_mean
def get_out_dim(self):
# This method seems specific to the original usage, might not be needed by HF
# self.out_dim = self.in_dim
# return self.out_dim
return self.in_dim
class TSDP(nn.Module):
"""
Temporal standard deviation pooling, only second-order std is considered
"""
def __init__(self, in_dim=0, **kwargs):
super(TSDP, self).__init__()
self.in_dim = in_dim
def forward(self, x):
# The last dimension is the temporal axis
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
pooling_std = pooling_std.flatten(start_dim=1)
return pooling_std
def get_out_dim(self):
# self.out_dim = self.in_dim
# return self.out_dim
return self.in_dim
class TSTP(nn.Module):
"""
Temporal statistics pooling, concatenate mean and std, which is used in
x-vector
Comment: simple concatenation can not make full use of both statistics
"""
def __init__(self, in_dim=0, **kwargs):
super(TSTP, self).__init__()
self.in_dim = in_dim
def forward(self, x):
# The last dimension is the temporal axis
pooling_mean = x.mean(dim=-1)
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
pooling_mean = pooling_mean.flatten(start_dim=1)
pooling_std = pooling_std.flatten(start_dim=1)
stats = torch.cat((pooling_mean, pooling_std), 1)
return stats
def get_out_dim(self):
# self.out_dim = self.in_dim * 2
# return self.out_dim
return self.in_dim * 2
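# Example (illustrative sketch): TSTP concatenates per-channel mean and std
# over time, so a (B, F, T) input becomes a (B, 2*F) embedding.
#
#   pool = TSTP(in_dim=80)
#   stats = pool(torch.randn(4, 80, 200))  # -> torch.Size([4, 160])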
class ASTP(nn.Module):
""" Attentive statistics pooling: Channel- and context-dependent
statistics pooling, first used in ECAPA_TDNN.
"""
def __init__(self,
in_dim,
bottleneck_dim=128,
global_context_att=False,
**kwargs):
super(ASTP, self).__init__()
self.in_dim = in_dim
self.global_context_att = global_context_att
# Use Conv1d with stride == 1 rather than Linear, then we don't
# need to transpose inputs.
if global_context_att:
self.linear1 = nn.Conv1d(
in_dim * 3, bottleneck_dim,
kernel_size=1) # equals W and b in the paper
else:
self.linear1 = nn.Conv1d(
in_dim, bottleneck_dim,
kernel_size=1) # equals W and b in the paper
self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
kernel_size=1) # equals V and k in the paper
def forward(self, x):
"""
x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
or a 4-dimensional tensor in resnet architecture (B,C,F,T)
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
"""
if len(x.shape) == 4:
x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
assert len(x.shape) == 3
if self.global_context_att:
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
context_std = torch.sqrt(
torch.var(x, dim=-1, keepdim=True) + 1e-7).expand_as(x)
x_in = torch.cat((x, context_mean, context_std), dim=1)
else:
x_in = x
        # Don't use ReLU here: it can make training hard to converge.
alpha = torch.tanh(
self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
alpha = torch.softmax(self.linear2(alpha), dim=2)
mean = torch.sum(alpha * x, dim=2)
var = torch.sum(alpha * (x**2), dim=2) - mean**2
std = torch.sqrt(var.clamp(min=1e-7))
return torch.cat([mean, std], dim=1)
def get_out_dim(self):
# self.out_dim = 2 * self.in_dim
# return self.out_dim
return self.in_dim * 2
class MHASTP(torch.nn.Module):
""" Multi head attentive statistics pooling
Reference:
Self Multi-Head Attention for Speaker Recognition
https://arxiv.org/pdf/1906.09890.pdf
"""
def __init__(self,
in_dim,
layer_num=2,
head_num=2,
d_s=1,
bottleneck_dim=64,
**kwargs):
super(MHASTP, self).__init__()
        assert (in_dim % head_num
                ) == 0  # in_dim must be divisible by head_num
self.in_dim = in_dim
self.head_num = head_num
d_model = int(in_dim / head_num)
channel_dims = [bottleneck_dim for i in range(layer_num + 1)]
if d_s > 1:
d_s = d_model
else:
d_s = 1
self.d_s = d_s
channel_dims[0], channel_dims[-1] = d_model, d_s
heads_att_trans = []
for i in range(self.head_num):
att_trans = nn.Sequential()
for j in range(layer_num - 1): # Use different loop variable
att_trans.add_module(
'att_' + str(j),
nn.Conv1d(channel_dims[j], channel_dims[j + 1], 1, 1))
att_trans.add_module('tanh' + str(j), nn.Tanh())
att_trans.add_module(
'att_' + str(layer_num - 1),
nn.Conv1d(channel_dims[layer_num - 1], channel_dims[layer_num],
1, 1))
heads_att_trans.append(att_trans)
self.heads_att_trans = nn.ModuleList(heads_att_trans)
def forward(self, input):
"""
input: a 3-dimensional tensor in xvector architecture
or a 4-dimensional tensor in resnet architecture
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
"""
if len(input.shape) == 4: # B x C x F x T
input = input.reshape(input.shape[0],
input.shape[1] * input.shape[2],
input.shape[3]) # B x (C*F) x T
assert len(input.shape) == 3
bs, f_dim, t_dim = input.shape
chunks = torch.chunk(input, self.head_num, 1)
# split
chunks_out = []
for i, layer in enumerate(self.heads_att_trans):
att_score = layer(chunks[i])
alpha = F.softmax(att_score, dim=-1)
mean = torch.sum(alpha * chunks[i], dim=2)
var = torch.sum(alpha * chunks[i]**2, dim=2) - mean**2
std = torch.sqrt(var.clamp(min=1e-7))
chunks_out.append(torch.cat((mean, std), dim=1))
out = torch.cat(chunks_out, dim=1)
return out
def get_out_dim(self):
# self.out_dim = 2 * self.in_dim
# return self.out_dim
return self.in_dim * 2
class MQMHASTP(torch.nn.Module):
""" An attentive pooling
Reference:
multi query multi head attentive statistics pooling
https://arxiv.org/pdf/2110.05042.pdf
Args:
in_dim: the feature dimension of input
        layer_num: the number of layers in the pooling layer
        query_num: the number of queries
head_num: the number of heads
bottleneck_dim: the bottleneck dimension
SA (H = 1, Q = 1, n = 2, d_s = 1) ref:
https://www.danielpovey.com/files/2018_interspeech_xvector_attention.pdf
MHA (H > 1, Q = 1, n = 1, d_s = 1) ref:
https://arxiv.org/pdf/1906.09890.pdf
AS (H = 1, Q > 1, n = 2, d_s = 1) ref:
https://arxiv.org/pdf/1803.10963.pdf
VSA (H = 1, Q > 1, n = 2, d_s = d_h) ref:
http://www.interspeech2020.org/uploadfile/pdf/Mon-2-10-5.pdf
"""
def __init__(self,
in_dim,
layer_num=2,
query_num=2,
head_num=8,
d_s=2,
bottleneck_dim=64,
**kwargs):
super(MQMHASTP, self).__init__()
self.n_query = nn.ModuleList([
MHASTP(in_dim,
layer_num=layer_num,
head_num=head_num,
d_s=d_s,
bottleneck_dim=bottleneck_dim) for i in range(query_num)
])
self.query_num = query_num
self.in_dim = in_dim
def forward(self, input):
"""
input: a 3-dimensional tensor in xvector architecture
or a 4-dimensional tensor in resnet architecture
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
"""
if len(input.shape) == 4: # B x C x F x T
input = input.reshape(input.shape[0],
input.shape[1] * input.shape[2],
input.shape[3]) # B x (C*F) x T
assert len(input.shape) == 3
res = []
for i, layer in enumerate(self.n_query):
res.append(layer(input))
out = torch.cat(res, dim=-1)
return out
def get_out_dim(self):
# self.out_dim = self.in_dim * 2 * self.query_num
# return self.out_dim
return self.in_dim * 2 * self.query_num
# ===============================================================
# End: Content from sparktts/modules/speaker/pooling_layers.py
# ===============================================================
# ===============================================================
# Start: Content from sparktts/modules/blocks/vocos.py
# ===============================================================
# Helper functions needed by VocosBackbone etc.
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d() if callable(d) else d
class AdaLayerNorm(nn.Module):
"""
Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
Args:
condition_dim (int): Dimension of the condition.
embedding_dim (int): Dimension of the embeddings.
"""
def __init__(self, condition_dim: int, embedding_dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.dim = embedding_dim
self.scale = nn.Linear(condition_dim, embedding_dim)
self.shift = nn.Linear(condition_dim, embedding_dim)
# Initialize weights similar to original implementation if needed
# torch.nn.init.ones_(self.scale.weight) # Might be default
# torch.nn.init.zeros_(self.shift.weight) # Might be default
if self.scale.bias is not None: nn.init.zeros_(self.scale.bias)
if self.shift.bias is not None: nn.init.zeros_(self.shift.bias)
def forward(self, x: torch.Tensor, cond_embedding: torch.Tensor) -> torch.Tensor:
scale = self.scale(cond_embedding)
shift = self.shift(cond_embedding)
x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
x = x * scale.unsqueeze(1) + shift.unsqueeze(1)
return x
class ConvNeXtBlock(nn.Module):
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
Args:
dim (int): Number of input channels.
intermediate_dim (int): Dimensionality of the intermediate layer.
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
Defaults to None.
condition_dim (int, optional): Dimension for AdaLayerNorm.
None means non-conditional LayerNorm. Defaults to None.
"""
def __init__(
self,
dim: int,
intermediate_dim: int,
layer_scale_init_value: float,
condition_dim: Optional[int] = None,
):
super().__init__()
self.dwconv = nn.Conv1d(
dim, dim, kernel_size=7, padding=3, groups=dim
) # depthwise conv
self.adanorm = condition_dim is not None
if self.adanorm:
self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
else:
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(
dim, intermediate_dim
) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(intermediate_dim, dim)
self.gamma = (
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
if layer_scale_init_value is not None and layer_scale_init_value > 0
else None
)
def forward(
self, x: torch.Tensor, cond_embedding: Optional[torch.Tensor] = None
) -> torch.Tensor:
residual = x
x = self.dwconv(x)
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
if self.adanorm:
assert cond_embedding is not None, "Conditioning embedding required for AdaLayerNorm"
x = self.norm(x, cond_embedding)
else:
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = self.gamma * x
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
x = residual + x
return x
class ResBlock1(nn.Module):
"""
ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
but without upsampling layers.
Args:
dim (int): Number of input channels.
kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
Defaults to (1, 3, 5).
lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
Defaults to 0.1.
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
Defaults to None.
"""
def __init__(
self,
dim: int,
kernel_size: int = 3,
dilation: Tuple[int, int, int] = (1, 3, 5),
lrelu_slope: float = 0.1,
layer_scale_init_value: Optional[float] = None,
):
super().__init__()
self.lrelu_slope = lrelu_slope
self.convs1 = nn.ModuleList(
[
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=dilation[0],
padding=self.get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=dilation[1],
padding=self.get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=dilation[2],
padding=self.get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs2 = nn.ModuleList(
[
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=1,
padding=self.get_padding(kernel_size, 1),
)
),
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=1,
padding=self.get_padding(kernel_size, 1),
)
),
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=1,
padding=self.get_padding(kernel_size, 1),
)
),
]
)
self.gamma = nn.ParameterList(
[
(
nn.Parameter(
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
)
if layer_scale_init_value is not None
else None
),
(
nn.Parameter(
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
)
if layer_scale_init_value is not None
else None
),
(
nn.Parameter(
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
)
if layer_scale_init_value is not None
else None
),
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
xt = c1(xt)
xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
xt = c2(xt)
if gamma is not None:
xt = gamma * xt
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
@staticmethod
def get_padding(kernel_size: int, dilation: int = 1) -> int:
return int((kernel_size * dilation - dilation) / 2)
class Backbone(nn.Module):
"""Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
C denotes input features, and L is the sequence length.
Returns:
Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
and H denotes the model dimension.
"""
raise NotImplementedError("Subclasses must implement the forward method.")
class VocosBackbone(Backbone):
"""
Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
Args:
input_channels (int): Number of input features channels.
dim (int): Hidden dimension of the model.
intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
num_layers (int): Number of ConvNeXtBlock layers.
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
condition_dim (int, optional): Dimension for AdaLayerNorm.
None means non-conditional model. Defaults to None.
"""
def __init__(
self,
input_channels: int,
dim: int,
intermediate_dim: int,
num_layers: int,
layer_scale_init_value: Optional[float] = None,
condition_dim: Optional[int] = None,
):
super().__init__()
self.input_channels = input_channels
self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
self.adanorm = condition_dim is not None
if self.adanorm:
self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
else:
self.norm = nn.LayerNorm(dim, eps=1e-6)
        layer_scale_init_value = (layer_scale_init_value or 1 / num_layers) if num_layers > 0 else None  # Handle num_layers=0
self.convnext = nn.ModuleList(
[
ConvNeXtBlock(
dim=dim,
intermediate_dim=intermediate_dim,
layer_scale_init_value=layer_scale_init_value,
condition_dim=condition_dim,
)
for _ in range(num_layers)
]
)
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Conv1d, nn.Linear)):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor, condition: Optional[torch.Tensor] = None) -> torch.Tensor:
# Input x: (B, C, L)
x = self.embed(x)
# After embed: (B, dim, L)
x_transposed = x.transpose(1, 2) # (B, L, dim)
if self.adanorm:
assert condition is not None
norm_out = self.norm(x_transposed, condition)
else:
norm_out = self.norm(x_transposed)
# After norm: (B, L, dim)
x = norm_out.transpose(1, 2) # (B, dim, L)
for conv_block in self.convnext:
x = conv_block(x, condition)
# After convnext blocks: (B, dim, L)
x = self.final_layer_norm(x.transpose(1, 2)) # (B, L, dim)
return x
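# Example (illustrative sketch): VocosBackbone consumes (B, C, L) and returns
# (B, L, dim); the sizes below are arbitrary.
#
#   backbone = VocosBackbone(input_channels=80, dim=256,
#                            intermediate_dim=512, num_layers=4)
#   y = backbone(torch.randn(2, 80, 100))  # -> torch.Size([2, 100, 256])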
class VocosResNetBackbone(Backbone):
"""
Vocos backbone module built with ResBlocks.
Args:
input_channels (int): Number of input features channels.
dim (int): Hidden dimension of the model.
num_blocks (int): Number of ResBlock1 blocks.
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
"""
def __init__(
self,
input_channels,
dim,
num_blocks,
layer_scale_init_value=None,
):
super().__init__()
self.input_channels = input_channels
self.embed = weight_norm(
nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
)
        layer_scale_init_value = (layer_scale_init_value or 1 / num_blocks / 3) if num_blocks > 0 else None  # Handle num_blocks=0
self.resnet = nn.Sequential(
*[
ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
for _ in range(num_blocks)
]
)
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
# Input x: (B, C, L)
x = self.embed(x)
# After embed: (B, dim, L)
x = self.resnet(x)
# After resnet: (B, dim, L)
x = x.transpose(1, 2) # (B, L, dim)
return x
# ===============================================================
# End: Content from sparktts/modules/blocks/vocos.py
# ===============================================================
# ===============================================================
# Start: Content from sparktts/modules/encoder_decoder/feat_decoder.py
# ===============================================================
class Decoder(nn.Module):
"""Decoder module with convnext and upsampling blocks
Args:
sample_ratios (List[int]): sample ratios
example: [2, 2] means upsample by 2x and then upsample by 2x
"""
def __init__(
self,
input_channels: int,
vocos_dim: int,
vocos_intermediate_dim: int,
vocos_num_layers: int,
out_channels: int,
condition_dim: int = None,
sample_ratios: List[int] = [1, 1],
use_tanh_at_final: bool = False,
):
super().__init__()
self.linear_pre = nn.Linear(input_channels, vocos_dim)
upsample_modules = []
current_dim = vocos_dim
for i, ratio in enumerate(sample_ratios):
upsample_modules.append(
nn.Sequential(
SamplingBlock(
dim=current_dim,
groups=current_dim, # Maybe use 1 or fewer groups if dim is high? Check original intent. Using current_dim for now.
upsample_scale=ratio,
),
# Note: The original code used VocosBackbone here, but it changes dims B,T,D -> B,D,T.
# SamplingBlock output is B,D,T, so VocosBackbone input matches.
# However, the VocosBackbone output is B,T,D, which doesn't fit the next SamplingBlock.
# Assuming the intent was to keep B,D,T format between sampling blocks.
# Replacing intermediate VocosBackbone with a simple Conv1d block to maintain format & refine.
nn.Conv1d(current_dim, current_dim, kernel_size=3, padding=1) # Simple refinement layer
# VocosBackbone(
# input_channels=current_dim,
# dim=current_dim,
# intermediate_dim=vocos_intermediate_dim // 2, # Smaller intermediate for efficiency?
# num_layers=2, # Fewer layers
# condition_dim=None,
# )
)
)
# No dimension change expected here if using Conv1d refinement
# If using VocosBackbone, need transpose logic
self.upsample = nn.Sequential(*upsample_modules)
# Final Backbone processes the fully upsampled features
self.vocos_backbone = VocosBackbone(
input_channels=current_dim, # Use the dim after upsampling
dim=vocos_dim, # Map back to main vocos_dim or keep current_dim? Using vocos_dim
intermediate_dim=vocos_intermediate_dim,
num_layers=vocos_num_layers,
condition_dim=condition_dim,
)
self.linear_post = nn.Linear(vocos_dim, out_channels)
self.use_tanh_at_final = use_tanh_at_final
def forward(self, x: torch.Tensor, c: torch.Tensor = None):
"""decoder forward.
Args:
x (torch.Tensor): (batch_size, input_channels, length)
c (torch.Tensor): (batch_size, condition_dim) - Optional condition
Returns:
x (torch.Tensor): (batch_size, out_channels, length_upsampled)
"""
# x: (B, C_in, T)
x = self.linear_pre(x.transpose(1, 2)) # (B, T, vocos_dim)
x = x.transpose(1, 2) # (B, vocos_dim, T)
# Apply upsampling blocks
x = self.upsample(x) # (B, vocos_dim, T_upsampled)
# Apply final backbone
x = self.vocos_backbone(x, condition=c) # (B, T_upsampled, vocos_dim)
x = self.linear_post(x) # (B, T_upsampled, C_out)
x = x.transpose(1, 2) # (B, C_out, T_upsampled)
if self.use_tanh_at_final:
x = torch.tanh(x)
return x
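# Example (illustrative sketch): with sample_ratios=[2, 2] the Decoder
# upsamples the time axis 4x overall; the channel sizes below are arbitrary,
# not the shipped BiCodec config.
#
#   dec = Decoder(input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
#                 vocos_num_layers=12, out_channels=512, sample_ratios=[2, 2])
#   y = dec(torch.randn(1, 1024, 50))  # -> torch.Size([1, 512, 200])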
# ===============================================================
# End: Content from sparktts/modules/encoder_decoder/feat_decoder.py
# ===============================================================
# ===============================================================
# Start: Content from sparktts/modules/encoder_decoder/feat_encoder.py
# ===============================================================
class Encoder(nn.Module):
"""Encoder module with convnext and downsampling blocks"""
def __init__(
self,
input_channels: int,
vocos_dim: int,
vocos_intermediate_dim: int,
vocos_num_layers: int,
out_channels: int,
sample_ratios: List[int] = [1, 1],
):
super().__init__()
"""
Encoder module with VocosBackbone and sampling blocks.
Args:
sample_ratios (List[int]): sample ratios
example: [2, 2] means downsample by 2x and then downsample by 2x
"""
# Initial Backbone processing
self.encoder_backbone = VocosBackbone(
input_channels=input_channels,
dim=vocos_dim,
intermediate_dim=vocos_intermediate_dim,
num_layers=vocos_num_layers, # Use main num_layers here
condition_dim=None,
)
downsample_modules = []
current_dim = vocos_dim
for i, ratio in enumerate(sample_ratios):
downsample_modules.append(
nn.Sequential(
SamplingBlock(
dim=current_dim,
groups=current_dim, # Again, check group size. Using current_dim.
downsample_scale=ratio,
),
# Add refinement layer (optional, similar to Decoder logic)
nn.Conv1d(current_dim, current_dim, kernel_size=3, padding=1)
# VocosBackbone( # Or a lighter VocosBackbone
# input_channels=current_dim,
# dim=current_dim,
# intermediate_dim=vocos_intermediate_dim // 2,
# num_layers=2,
# condition_dim=None,
# )
)
)
# No dimension change expected here
self.downsample = nn.Sequential(*downsample_modules)
self.project = nn.Linear(current_dim, out_channels) # Project from the final dimension
def forward(self, x: torch.Tensor, *args):
"""
Args:
x (torch.Tensor): (batch_size, input_channels, length)
Returns:
x (torch.Tensor): (batch_size, out_channels, length_downsampled)
"""
# x: (B, C_in, T)
x = self.encoder_backbone(x) # (B, T, vocos_dim)
x = x.transpose(1, 2) # (B, vocos_dim, T)
# Apply downsampling blocks
x = self.downsample(x) # (B, vocos_dim, T_downsampled)
x = x.transpose(1, 2) # (B, T_downsampled, vocos_dim)
x = self.project(x) # (B, T_downsampled, C_out)
return x.transpose(1, 2) # (B, C_out, T_downsampled)
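# Example (illustrative sketch): the Encoder mirrors the Decoder, downsampling
# time by the product of sample_ratios; the sizes below are arbitrary.
#
#   enc = Encoder(input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
#                 vocos_num_layers=12, out_channels=1024, sample_ratios=[2, 2])
#   z = enc(torch.randn(1, 1024, 200))  # -> torch.Size([1, 1024, 50])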
# ===============================================================
# End: Content from sparktts/modules/encoder_decoder/feat_encoder.py
# ===============================================================
# ===============================================================
# Start: Content from sparktts/modules/encoder_decoder/wave_generator.py
# ===============================================================
class DecoderBlock(nn.Module):
def __init__(
self,
input_dim: int = 16,
output_dim: int = 8,
kernel_size: int = 2,
stride: int = 1,
):
super().__init__()
# Ensure stride is at least 1
stride = max(1, stride)
# Ensure kernel_size is valid for ConvTranspose1d
if kernel_size < stride:
kernel_size = stride # Or handle differently
padding = (kernel_size - stride) // 2
output_padding = stride % 2 if kernel_size % 2 == 0 else 0 # Basic calculation, might need adjustment based on desired output length
# print(f"DecoderBlock - Input: {input_dim}, Output: {output_dim}, Kernel: {kernel_size}, Stride: {stride}, Padding: {padding}, OutputPadding: {output_padding}")
self.block = nn.Sequential(
Snake1d(input_dim),
WNConvTranspose1d(
input_dim,
output_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding, # Add output_padding
),
ResidualUnit(output_dim, dilation=1),
ResidualUnit(output_dim, dilation=3),
ResidualUnit(output_dim, dilation=9),
)
def forward(self, x):
return self.block(x)
class WaveGenerator(nn.Module):
def __init__(
self,
input_channel,
channels,
rates,
kernel_sizes,
d_out: int = 1,
):
super().__init__()
# Add first conv layer
layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
# Add upsampling + MRF blocks
current_channels = channels
for i, (kernel_size, stride) in enumerate(zip(kernel_sizes, rates)):
input_dim = current_channels
# Ensure output_dim doesn't go below 1
output_dim = max(1, channels // (2 ** (i + 1)))
layers += [DecoderBlock(input_dim, output_dim, kernel_size, stride)]
current_channels = output_dim # Update for the next block's input
# Add final conv layer
layers += [
Snake1d(current_channels), # Use the final output_dim
WNConv1d(current_channels, d_out, kernel_size=7, padding=3),
nn.Tanh(),
]
self.model = nn.Sequential(*layers)
self.apply(init_weights) # Apply weight initialization
def forward(self, x):
return self.model(x)
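# Example (illustrative sketch): WaveGenerator's overall upsampling factor is
# the product of `rates`; the settings below are arbitrary, not the shipped
# config.
#
#   gen = WaveGenerator(input_channel=1024, channels=512,
#                       rates=[4, 4], kernel_sizes=[8, 8])
#   wav = gen(torch.randn(1, 1024, 50))  # -> torch.Size([1, 1, 800])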
# ===============================================================
# End: Content from sparktts/modules/encoder_decoder/wave_generator.py
# ===============================================================
# ===============================================================
# Start: Content from sparktts/modules/fsq/finite_scalar_quantization.py
# ===============================================================
# helper functions moved earlier
def round_ste(z: Tensor) -> Tensor:
"""Round with straight through gradients."""
zhat = z.round()
return z + (zhat - z).detach()
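# Example (illustrative sketch): round_ste rounds in the forward pass but lets
# gradients flow through unchanged.
#
#   z = torch.tensor([0.2, 1.7], requires_grad=True)
#   round_ste(z).sum().backward()
#   print(z.grad)  # tensor([1., 1.])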
class FSQ(nn.Module):
def __init__(
self,
levels: List[int],
dim: int | None = None,
num_codebooks=1,
keep_num_codebooks_dim: bool | None = None,
scale: float | None = None,
allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
channel_first: bool = False, # Added based on usage in ResidualFSQ
projection_has_bias: bool = True,
return_indices=True,
force_quantization_f32=True,
):
super().__init__()
_levels = torch.tensor(levels, dtype=int32)
self.register_buffer("_levels", _levels, persistent=False)
_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
self.register_buffer("_basis", _basis, persistent=False)
self.scale = scale # Not used in current implementation, but kept
codebook_dim = len(levels)
self.codebook_dim = codebook_dim
effective_codebook_dim = codebook_dim * num_codebooks
self.num_codebooks = num_codebooks
self.effective_codebook_dim = effective_codebook_dim
# keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
# Force keep_num_codebooks_dim to False if num_codebooks is 1
if num_codebooks == 1:
keep_num_codebooks_dim = False
else:
keep_num_codebooks_dim = default(keep_num_codebooks_dim, True)
# Original assert was checking if num_codebooks > 1 and keep_num_codebooks_dim is False. Let's refine.
# If num_codebooks > 1, keep_num_codebooks_dim must be True based on how rearrange is used.
if num_codebooks > 1 and not keep_num_codebooks_dim:
raise ValueError("If num_codebooks > 1, keep_num_codebooks_dim must be True or None (defaults to True).")
self.keep_num_codebooks_dim = keep_num_codebooks_dim
self.dim = default(dim, len(_levels) * num_codebooks)
self.channel_first = channel_first # Store channel_first setting
has_projections = self.dim != effective_codebook_dim
self.project_in = (
nn.Linear(self.dim, effective_codebook_dim, bias=projection_has_bias)
if has_projections
else nn.Identity()
)
self.project_out = (
nn.Linear(effective_codebook_dim, self.dim, bias=projection_has_bias)
if has_projections
else nn.Identity()
)
self.has_projections = has_projections
self.return_indices = return_indices
if return_indices:
self.codebook_size = self._levels.prod().item()
# Calculate implicit codebook based on current device during forward pass if needed
# For now, calculate assuming CPU and move later if necessary
# implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size, device=self._levels.device)) # Calculate on device
# self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)
self.allowed_dtypes = allowed_dtypes
self.force_quantization_f32 = force_quantization_f32
@property
def implicit_codebook(self):
# Calculate implicit codebook on the fly using the device of _levels
device = self._levels.device
indices = torch.arange(self.codebook_size, device=device)
return self._indices_to_codes(indices)
    def bound(self, z, eps: float = 1e-3):
        """Bound `z`, an array of shape (..., d), to the valid level range."""
        levels = self._levels.to(z.device)  # Ensure levels are on the same device
        # The reference FSQ implementation uses a tanh-based soft bound:
        #   half_l = (levels - 1) * (1 + eps) / 2
        #   offset = torch.where(levels % 2 == 0, 0.5, 0.0)
        #   shift = (offset / half_l).atanh()
        #   return (z + shift).tanh() * half_l - offset
        # Here a hard clamp to the level range is used instead.
        upper_bound = (levels - 1) / 2
        lower_bound = -upper_bound
        upper_bound = upper_bound.view(1, 1, -1) if z.ndim == 3 else upper_bound
        lower_bound = lower_bound.view(1, 1, -1) if z.ndim == 3 else lower_bound
        return torch.clamp(z, min=lower_bound, max=upper_bound)
def quantize(self, z):
"""Quantizes z, returns quantized zhat, same shape as z."""
quantized = round_ste(self.bound(z))
levels = self._levels.to(z.device)
half_width = levels // 2 # Renormalize to [-1, 1].
# Avoid division by zero if level is 1
half_width = torch.where(half_width == 0, torch.tensor(1.0, device=z.device), half_width.float())
half_width_view = half_width.view(1, 1, -1) if quantized.ndim == 3 else half_width
return quantized / half_width_view
def _scale_and_shift(self, zhat_normalized):
levels = self._levels.to(zhat_normalized.device)
half_width = levels // 2
half_width_view = half_width.view(1, 1, -1) if zhat_normalized.ndim == 3 else half_width
return (zhat_normalized * half_width_view) + half_width_view
def _scale_and_shift_inverse(self, zhat):
levels = self._levels.to(zhat.device)
half_width = levels // 2
# Avoid division by zero if level is 1
half_width = torch.where(half_width == 0, torch.tensor(1.0, device=zhat.device), half_width.float())
half_width_view = half_width.view(1, 1, -1) if zhat.ndim == 3 else half_width
return (zhat - half_width_view) / half_width_view
def _indices_to_codes(self, indices):
level_indices = self.indices_to_level_indices(indices)
codes = self._scale_and_shift_inverse(level_indices.float()) # Convert level indices to float
return codes
def codes_to_indices(self, zhat):
"""Converts a `code` to an index in the codebook."""
assert zhat.shape[-1] == self.codebook_dim
zhat_scaled = self._scale_and_shift(zhat)
# Ensure basis is on the correct device and dtype, handle potential shape mismatch
basis = self._basis.to(zhat.device, dtype=int32)
basis_view = basis.view(1, 1, -1) if zhat_scaled.ndim == 3 else basis # Match ndim
# Ensure zhat_scaled is integer type for multiplication with basis
product = (zhat_scaled * basis_view).round().int()
return product.sum(dim=-1).to(int32)
def indices_to_level_indices(self, indices):
"""Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings"""
indices_reshaped = rearrange(indices, "... -> ... 1")
basis = self._basis.to(indices.device)
levels = self._levels.to(indices.device)
# Ensure basis and levels match the device and potentially ndim of indices
basis_view = basis.view(*([1] * (indices_reshaped.ndim - 1)), -1)
levels_view = levels.view(*([1] * (indices_reshaped.ndim - 1)), -1)
codes_non_centered = (indices_reshaped // basis_view) % levels_view
return codes_non_centered
# indices_to_codes is now handled by implicit_codebook property + project_out if needed
def forward(self, z):
"""
einstein notation
b - batch
... - sequence, spatial dimensions
d - feature dimension
c - number of codebook dim (within a single quantizer)
g - number of quantizers (groups) - handled by ResidualFSQ/GroupedResidualFSQ
"""
# Input z can be (b d ...) or (b ... d)
# self.channel_first determines the expected input format for projection
if self.channel_first:
# Expects (b d ...)
if z.ndim > 2: # Has spatial/temporal dims
z = rearrange(z, "b d ... -> b ... d")
z, ps = pack([z], "b * d")
# else: z is (b d) -> processed directly by linear
else:
# Expects (b ... d)
if z.ndim > 2:
z, ps = pack([z], "b * d")
# else: z is (b d) -> processed directly by linear
assert (
z.shape[-1] == self.dim
), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
# Project in
z_projected = self.project_in(z) # (b ... effective_codebook_dim)
# Reshape for codebooks if num_codebooks > 1
if self.num_codebooks > 1:
z_reshaped = rearrange(z_projected, "b ... (c d) -> b ... c d", c=self.num_codebooks)
else:
# Add a dummy codebook dim for consistent processing
z_reshaped = rearrange(z_projected, "b ... d -> b ... 1 d")
# Force quantization step to be full precision or not
force_f32 = self.force_quantization_f32
quantization_context = (
partial(autocast, "cuda", enabled=False) if force_f32 else nullcontext
)
codes = None
indices = None
with quantization_context():
orig_dtype = z_reshaped.dtype
if force_f32 and orig_dtype not in self.allowed_dtypes:
z_for_quant = z_reshaped.float()
else:
z_for_quant = z_reshaped
codes = self.quantize(z_for_quant) # (b ... c d)
if self.return_indices:
indices = self.codes_to_indices(codes) # (b ... c)
# Convert codes back to original dtype if changed
codes = codes.type(orig_dtype)
# Reshape codes back and project out
if self.num_codebooks > 1:
codes_reshaped = rearrange(codes, "b ... c d -> b ... (c d)")
else:
codes_reshaped = rearrange(codes, "b ... 1 d -> b ... d")
out = self.project_out(codes_reshaped) # (b ... dim)
# Restore original spatial/temporal dimensions
if z.ndim > 2: # If we packed dimensions
out = unpack(out, ps, "b * d")[0]
if self.return_indices:
indices = unpack(indices, ps, "b * c")[0]
# Restore channel dimension if needed
if self.channel_first and out.ndim > 2:
out = rearrange(out, "b ... d -> b d ...")
if self.return_indices and indices.ndim > 1: # Check indices ndim
# Indices shape (b ... c), need to decide how to handle channel dim
# Often indices might not need channel dim, depends on usage
# If indices are e.g. (b H W c), permuting might be complex.
# Keeping indices as (b ... c) for now.
pass
# Remove the dummy codebook dim from indices if num_codebooks was 1
if self.return_indices and self.num_codebooks == 1 and not self.keep_num_codebooks_dim:
indices = indices.squeeze(-1)
return out, indices
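# Example (illustrative sketch): a single FSQ quantizer with levels [8, 5, 5, 5]
# has an implicit codebook of 8*5*5*5 = 1000 entries.
#
#   fsq = FSQ(levels=[8, 5, 5, 5])
#   out, indices = fsq(torch.randn(2, 100, 4))
#   # out: (2, 100, 4), indices: (2, 100) with values in [0, 999]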
# ===============================================================
# End: Content from sparktts/modules/fsq/finite_scalar_quantization.py
# ===============================================================
# ===============================================================
# Start: Content from sparktts/modules/fsq/residual_fsq.py
# ===============================================================
# Helper functions needed by ResidualFSQ
def is_distributed():
return dist.is_initialized() and dist.get_world_size() > 1
def get_maybe_sync_seed(device, max_size=10_000):
rand_int = torch.randint(0, max_size, (), device=device)
if is_distributed():
# Ensure rand_int is on the correct device for all_reduce
if rand_int.device != device:
rand_int = rand_int.to(device)
dist.all_reduce(rand_int)
return rand_int.item()
def round_up_multiple(num, mult):
# Ensure mult is positive
if mult <= 0:
return num
# Use ceiling division
return (num + mult - 1) // mult * mult
class ResidualFSQ(nn.Module):
"""Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
def __init__(
self,
*,
levels: List[int],
num_quantizers,
dim=None,
# is_channel_first=False, # Handled inside FSQ now
quantize_dropout=False,
quantize_dropout_cutoff_index=0,
quantize_dropout_multiple_of=1,
channel_first: bool = False, # Pass channel_first to FSQ
**kwargs, # Pass remaining kwargs to FSQ
):
super().__init__()
codebook_dim = len(levels)
dim = default(dim, codebook_dim)
requires_projection = codebook_dim != dim
self.project_in = (
nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
)
self.has_projections = requires_projection
self.channel_first = channel_first # Store for potential shape adjustments if needed later
self.num_quantizers = num_quantizers
self.levels = levels
self.layers = nn.ModuleList([])
levels_tensor = torch.Tensor(levels)
scales = []
for ind in range(num_quantizers):
# Calculate scale: (levels - 1) is max value range (- (l-1)/2 to +(l-1)/2)
# Residual is divided by scale before quantization
# Effective scale for quantizer 'ind' is (levels - 1)^ind ? Needs check.
# Original paper scale seems different. Let's stick to FSQ handling scale internally if needed.
# Using scale = 1.0 for now, assuming FSQ handles normalization.
scale_value = 1.0 # ((levels_tensor - 1)**-ind) - Check this logic
scales.append(scale_value)
# Pass channel_first to FSQ
fsq = FSQ(levels=levels, dim=codebook_dim, channel_first=channel_first, **kwargs)
self.layers.append(fsq)
# Check if FSQ layers have projections internally. ResidualFSQ should handle overall projection.
assert all([not fsq.has_projections for fsq in self.layers]), "FSQ layers within ResidualFSQ should not have internal projections."
self.codebook_size = self.layers[0].codebook_size
# Using scale = 1.0, so register_buffer might not be needed, or store 1.0s
# self.register_buffer("scales", torch.Tensor(scales), persistent=False)
# If scales are needed, they should likely be parameters or calculated differently.
# For now, assuming FSQ normalizes correctly and scale is 1.0 here.
self.quantize_dropout = quantize_dropout and num_quantizers > 1
assert quantize_dropout_cutoff_index >= 0
self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4
@property
def codebooks(self):
# Codebooks are implicit in FSQ, access via property
codebooks = [layer.implicit_codebook for layer in self.layers]
codebooks = torch.stack(codebooks, dim=0)
return codebooks
def get_codes_from_indices(self, indices):
# indices shape: (b ... q) or (b q ...) depending on usage
num_dims = indices.ndim
q_dim = -1 # Assume last dim is quantizer dim by default
# Find the quantizer dimension (q)
for i in range(num_dims):
if indices.shape[i] == self.num_quantizers:
q_dim = i
break
if q_dim == -1 and self.num_quantizers == 1 and indices.shape[-1] != 1:
# If only 1 quantizer, indices might not have the quantizer dim explicitly
indices = indices.unsqueeze(-1) # Add the quantizer dim
q_dim = -1
elif q_dim == -1:
raise ValueError(f"Could not find quantizer dimension ({self.num_quantizers}) in indices shape {indices.shape}")
# Ensure q_dim is the last dimension for processing
if q_dim != num_dims - 1:
permute_dims = list(range(num_dims))
permute_dims.pop(q_dim)
permute_dims.append(q_dim)
indices = indices.permute(*permute_dims)
batch_shape = indices.shape[:-1] # Shape before the quantizer dim
indices = indices.reshape(-1, self.num_quantizers) # Flatten batch/spatial dims
# Handle dropout indices (-1)
if indices.max() >= self.codebook_size:
raise ValueError(f"Invalid index found in indices: {indices.max()}. Max allowed is {self.codebook_size - 1}.")
if indices.min() < -1:
raise ValueError(f"Invalid index found in indices: {indices.min()}. Min allowed is -1 (dropout).")
mask = indices == -1
effective_indices = indices.masked_fill(mask, 0) # Use 0 for dropout indices temporarily
all_codes = []
# Iterate through each quantizer layer
for i in range(self.num_quantizers):
layer_indices = effective_indices[:, i]
# Use the FSQ layer's method to convert indices to codes (handles normalization)
# Need to ensure indices_to_codes exists and works correctly in FSQ
# Assuming FSQ.indices_to_codes takes (batch,) indices and returns (batch, codebook_dim) codes
layer_codes = self.layers[i].indices_to_codes(layer_indices) # This needs correct FSQ method
all_codes.append(layer_codes)
all_codes_tensor = torch.stack(all_codes, dim=0) # (q, b_flat, d)
# Mask out dropout codes
mask_expanded = mask.permute(1, 0).unsqueeze(-1) # (q, b_flat, 1)
all_codes_tensor = all_codes_tensor.masked_fill(mask_expanded, 0.0)
# Reshape back to original batch/spatial shape
all_codes_tensor = all_codes_tensor.reshape(self.num_quantizers, *batch_shape, -1) # (q, b ... d)
# Restore original q_dim position if it was changed
        # NOTE: if the quantizer dim was moved to the end above, the inverse
        # permutation is intentionally not applied here; codes are returned as
        # (q, b ... d), which is what get_output_from_indices expects.
return all_codes_tensor
def get_output_from_indices(self, indices):
# indices shape: (b ... q)
codes = self.get_codes_from_indices(indices) # Output: (q, b ... d)
codes_summed = reduce(codes, "q b ... d -> b ... d", "sum")
# Project back to original dimension
output = self.project_out(codes_summed)
# Handle channel first if necessary for the final output
if self.channel_first and output.ndim > 2:
# Assumes input was (b d ...), so output should be too
output = rearrange(output, "b ... d -> b d ...")
return output
def forward(self, x, return_all_codes=False, rand_quantize_dropout_fixed_seed=None):
num_quant, quant_dropout_multiple_of, device = (
self.num_quantizers,
self.quantize_dropout_multiple_of,
x.device,
)
# handle channel first input if necessary for projection
original_shape = x.shape
if self.channel_first:
if x.ndim > 2: # Has spatial/temporal dims
x = rearrange(x, "b d ... -> b ... d")
x, ps = pack([x], "b * d")
# else: x is (b d), processed directly
else:
# Input is (b ... d)
if x.ndim > 2:
x, ps = pack([x], "b * d")
# else: x is (b d), processed directly
# maybe project in
projected_x = self.project_in(x) # (b ... codebook_dim)
quantized_out = 0.0
residual = projected_x # Start residual from projected input
all_indices = []
should_quantize_dropout = self.training and self.quantize_dropout
# sample a layer index at which to dropout further residual quantization
# also prepare null indices
rand_quantize_dropout_index = num_quant # Default to no dropout
if should_quantize_dropout:
if not exists(rand_quantize_dropout_fixed_seed):
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
rand = random.Random(rand_quantize_dropout_fixed_seed)
# Ensure cutoff index is valid
valid_cutoff = max(0, self.quantize_dropout_cutoff_index)
rand_quantize_dropout_index = rand.randrange(valid_cutoff, num_quant)
if quant_dropout_multiple_of != 1:
rand_quantize_dropout_index = (
round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1
)
# Clamp index to be within valid range
rand_quantize_dropout_index = min(rand_quantize_dropout_index, num_quant - 1)
# Null indices shape should match the batch/spatial dims of x before pack
null_indices_shape = list(x.shape[:-1]) # All dims except last feature dim
null_indices = torch.full(null_indices_shape, -1, device=device, dtype=torch.long)
# go through the layers
# Assuming scale is handled within FSQ or is 1.0 here
# scales = self.scales.to(device)
for quantizer_index, layer in enumerate(self.layers):
# scale = scales[quantizer_index] # If using external scales
if quantizer_index > rand_quantize_dropout_index:
# Append null indices matching the shape of valid indices from FSQ
# FSQ returns indices shape (b ...) or (b ... c) -> need (b ...)
# Use the pre-calculated null_indices
all_indices.append(null_indices)
continue
# Pass residual to the quantizer layer
# Assume FSQ takes (b ... d) or (b d ...) based on its channel_first setting
# Here, residual is (b ... codebook_dim)
quantized, indices = layer(residual) # layer should handle channel_first internally
            # Residual VQ (Algorithm 1 in the paper / Encodec-style): quantize the
            # current residual, subtract the detached quantized value from the
            # residual, and accumulate the per-layer quantized outputs.
quantized_detached = quantized.detach()
residual = residual - quantized_detached
quantized_out = quantized_out + quantized # Sum quantized outputs from each layer
# Store indices
if indices is None:
raise ValueError(f"FSQ layer {quantizer_index} did not return indices.")
all_indices.append(indices)
# project out the summed quantized output
final_quantized_out = self.project_out(quantized_out) # (b ... dim)
# stack all indices
all_indices = torch.stack(all_indices, dim=-1) # (b ... q)
# Restore original shape if packed
if x.ndim > 2: # If we packed dimensions
final_quantized_out = unpack(final_quantized_out, ps, "b * d")[0]
all_indices = unpack(all_indices, ps, "b * q")[0]
# Restore channel dimension if needed
if self.channel_first and final_quantized_out.ndim > 2:
final_quantized_out = rearrange(final_quantized_out, "b ... d -> b d ...")
# Decide how to handle indices shape. Keep as (b ... q) or (b q ...)?
# Keeping as (b ... q) seems more common.
# all_indices = rearrange(all_indices, "b ... q -> b q ...") # Optional rearrange
# return
ret = (final_quantized_out, all_indices)
if not return_all_codes:
return ret
# Return all codes (reconstructed from indices)
# Input to get_codes_from_indices should be (b ... q)
all_codes = self.get_codes_from_indices(all_indices) # Output (q, b ... d)
# Maybe reshape all_codes to match input shape conventions?
# If input was channel_first (b d ...), maybe output codes as (q b d ...)?
if self.channel_first and all_codes.ndim > 3:
all_codes = rearrange(all_codes, "q b ... d -> q b d ...")
return (*ret, all_codes)
# ===============================================================
# End: Content from sparktts/modules/fsq/residual_fsq.py
# ===============================================================
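# Comment-only sketch of the residual quantization recurrence implemented above,
# with a simple rounding quantizer standing in for FSQ (shapes and the toy
# quantizer are illustrative assumptions, not the real layers):
#
#   import torch
#
#   def toy_quantizer(x):                # stand-in for one FSQ layer
#       return torch.round(x * 4) / 4
#
#   x = torch.randn(2, 10, 8)            # (batch, time, dim)
#   residual, quantized_out = x, torch.zeros_like(x)
#   for _ in range(3):                   # three residual layers
#       q = toy_quantizer(residual)
#       residual = residual - q          # the next layer quantizes what is left
#       quantized_out = quantized_out + q
#   # quantized_out approximates x more closely as layers are added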
# ===============================================================
# Start: Content from sparktts/modules/speaker/ecapa_tdnn.py
# ===============================================================
class Res2Conv1dReluBn(nn.Module):
"""
in_channels == out_channels == channels
"""
def __init__(
self,
channels,
kernel_size=1,
stride=1,
padding=0,
dilation=1,
bias=True,
scale=4,
):
super().__init__()
assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
self.scale = scale
self.width = channels // scale
self.nums = scale if scale == 1 else scale - 1
self.convs = []
self.bns = []
for i in range(self.nums):
self.convs.append(
nn.Conv1d(
self.width,
self.width,
kernel_size,
stride,
padding,
dilation,
bias=bias,
)
)
self.bns.append(nn.BatchNorm1d(self.width))
self.convs = nn.ModuleList(self.convs)
self.bns = nn.ModuleList(self.bns)
def forward(self, x):
out = []
spx = torch.split(x, self.width, 1)
sp = spx[0]
# Enumerate starts from 0, matching list indices
for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
# Order: conv -> relu -> bn
if i >= 1:
sp = sp + spx[i] # Residual connection within block parts
sp = conv(sp)
sp = bn(F.relu(sp)) # Apply ReLU before BatchNorm
out.append(sp)
if self.scale != 1:
# Append the last chunk without processing if scale > 1
out.append(spx[self.nums])
out = torch.cat(out, dim=1)
return out
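# Comment-only shape sketch for Res2Conv1dReluBn (values are illustrative): with
# channels=512 and scale=8 the input splits into 8 chunks of width 64; the first
# 7 chunks go through the conv+bn branches (each adding the previous branch's
# output), the last chunk passes through unchanged, and the chunks are
# concatenated back to 512 channels.
#
#   block = Res2Conv1dReluBn(channels=512, kernel_size=3, padding=1, scale=8)
#   y = block(torch.randn(2, 512, 100))   # -> (2, 512, 100)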
""" Conv1d + BatchNorm1d + ReLU """
class Conv1dReluBn(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
dilation=1,
bias=True,
):
super().__init__()
self.conv = nn.Conv1d(
in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
)
self.bn = nn.BatchNorm1d(out_channels)
def forward(self, x):
# The ECAPA-TDNN reference describes Conv1d + ReLU + BatchNorm; some implementations
# instead use conv -> bn -> relu. This keeps the original code's bn(relu(conv(x))) ordering.
return self.bn(F.relu(self.conv(x)))
""" The SE connection of 1D case. """
class SE_Connect(nn.Module):
def __init__(self, channels, se_bottleneck_dim=128):
super().__init__()
self.linear1 = nn.Linear(channels, se_bottleneck_dim)
self.linear2 = nn.Linear(se_bottleneck_dim, channels)
def forward(self, x):
# x shape: (B, C, T)
out = x.mean(dim=2) # Global average pooling over time -> (B, C)
out = F.relu(self.linear1(out))
out = torch.sigmoid(self.linear2(out))
out = x * out.unsqueeze(2) # (B, C, T) * (B, C, 1) -> (B, C, T)
return out
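# Comment-only sketch of SE_Connect (illustrative sizes): channel-wise gates are
# computed from time-averaged features and broadcast back over time, so the
# output keeps the input shape with rescaled channels.
#
#   se = SE_Connect(channels=512)
#   y = se(torch.randn(2, 512, 100))   # -> (2, 512, 100)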
""" SE-Res2Block of the ECAPA-TDNN architecture. """
class SE_Res2Block(nn.Module):
def __init__(self, channels, kernel_size, stride, padding, dilation, scale):
super().__init__()
self.se_res2block = nn.Sequential(
Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
Res2Conv1dReluBn(
channels, kernel_size, stride, padding, dilation, scale=scale
),
Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
SE_Connect(channels),
)
def forward(self, x):
return x + self.se_res2block(x)
class ECAPA_TDNN(nn.Module):
def __init__(
self,
channels=512,
feat_dim=80,
embed_dim=192,
pooling_func="ASTP",
global_context_att=False,
emb_bn=False,
):
super().__init__()
self.layer1 = Conv1dReluBn(feat_dim, channels, kernel_size=5, padding=2)
self.layer2 = SE_Res2Block(
channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8
)
self.layer3 = SE_Res2Block(
channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8
)
self.layer4 = SE_Res2Block(
channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8
)
cat_channels = channels * 3
# Keep the concatenated channel count through the 1x1 conv (3 * channels, i.e. 1536
# in the c512 configuration) so the pooling layer receives the feature size it expects.
self.conv = nn.Conv1d(cat_channels, cat_channels, kernel_size=1) # Keep channels same for pooling
# Dynamically get pooling class based on string name from pooling_layers (defined earlier)
if pooling_func == "TAP": pooling_layer = TAP
elif pooling_func == "TSDP": pooling_layer = TSDP
elif pooling_func == "TSTP": pooling_layer = TSTP
elif pooling_func == "ASTP": pooling_layer = ASTP
elif pooling_func == "MHASTP": pooling_layer = MHASTP
elif pooling_func == "MQMHASTP": pooling_layer = MQMHASTP
else: raise ValueError(f"Unsupported pooling function: {pooling_func}")
self.pool = pooling_layer(
in_dim=cat_channels, # Pooling operates on the output of self.conv
global_context_att=global_context_att # Pass context flag if relevant (ASTP)
# Add other necessary kwargs for specific pooling layers if needed
)
# Determine the pooling output dimension: prefer the layer's own get_out_dim() when
# available; otherwise fall back to the usual convention (statistics-style poolings
# return mean and std, i.e. 2 * in_dim; MQMHASTP additionally scales with query_num).
if hasattr(self.pool, 'get_out_dim'):
self.pool_out_dim = self.pool.get_out_dim()
elif isinstance(self.pool, (TSTP, ASTP, MHASTP, MQMHASTP)):
# Assuming these double the input dimension
self.pool_out_dim = cat_channels * (2 * getattr(self.pool, 'query_num', 1) if isinstance(self.pool, MQMHASTP) else 2)
else: # TAP, TSDP
self.pool_out_dim = cat_channels
self.bn = nn.BatchNorm1d(self.pool_out_dim)
self.linear = nn.Linear(self.pool_out_dim, embed_dim)
self.emb_bn = emb_bn
if emb_bn: # better in SSL for SV
self.bn2 = nn.BatchNorm1d(embed_dim)
else:
self.bn2 = nn.Identity()
def forward(self, x, return_latent=False):
# Input x expected as (B, T, F) e.g., mels
x = x.permute(0, 2, 1) # (B, T, F) -> (B, F, T)
out1 = self.layer1(x)
out2 = self.layer2(out1)
out3 = self.layer3(out2)
out4 = self.layer4(out3)
# Concat features from layers 2, 3, 4
out = torch.cat([out2, out3, out4], dim=1) # (B, 3*channels, T)
latent = F.relu(self.conv(out)) # (B, 3*channels, T)
# Pooling expects (B, F, T)
pooled_out = self.pool(latent) # (B, pool_out_dim)
bn_out = self.bn(pooled_out)
embedding = self.linear(bn_out) # (B, embed_dim)
if self.emb_bn:
embedding = self.bn2(embedding)
if return_latent:
# Return the embedding and the features before pooling
return embedding, latent # latent shape (B, 3*channels, T)
return embedding # Return only the final embedding
# Factory functions (optional, but keep if used elsewhere)
def ECAPA_TDNN_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
return ECAPA_TDNN(
channels=1024,
feat_dim=feat_dim,
embed_dim=embed_dim,
pooling_func=pooling_func,
emb_bn=emb_bn,
)
def ECAPA_TDNN_GLOB_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
return ECAPA_TDNN(
channels=1024,
feat_dim=feat_dim,
embed_dim=embed_dim,
pooling_func=pooling_func,
global_context_att=True,
emb_bn=emb_bn,
)
def ECAPA_TDNN_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
return ECAPA_TDNN(
channels=512,
feat_dim=feat_dim,
embed_dim=embed_dim,
pooling_func=pooling_func,
emb_bn=emb_bn,
)
def ECAPA_TDNN_GLOB_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
return ECAPA_TDNN(
channels=512,
feat_dim=feat_dim,
embed_dim=embed_dim,
pooling_func=pooling_func,
global_context_att=True,
emb_bn=emb_bn,
)
# ===============================================================
# End: Content from sparktts/modules/speaker/ecapa_tdnn.py
# ===============================================================
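# Comment-only usage sketch of the ECAPA-TDNN defined above (batch size, frame
# count and feature dim are arbitrary assumptions):
#
#   ecapa = ECAPA_TDNN_GLOB_c512(feat_dim=80, embed_dim=192)
#   mels = torch.randn(2, 200, 80)                    # (B, T_mel, n_mels)
#   emb = ecapa(mels)                                 # (2, 192)
#   emb, latent = ecapa(mels, return_latent=True)     # latent: (2, 1536, 200)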
# ===============================================================
# Start: Content from sparktts/modules/speaker/perceiver_encoder.py
# ===============================================================
# Helper functions for Perceiver/Attention
def exists(val): # Redefined earlier
return val is not None
def once(fn): # Redefined earlier
called = False
@wraps(fn)
def inner(x):
nonlocal called
if called: return
called = True
return fn(x)
return inner
print_once = once(print)
class Attend(nn.Module):
def __init__(self, dropout=0.0, causal=False, use_flash=False):
super().__init__()
self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)
self.causal = causal
self.register_buffer("mask", None, persistent=False)
self.use_flash = use_flash
can_use_flash = hasattr(F, 'scaled_dot_product_attention') and use_flash
if can_use_flash:
print_once("Using Flash Attention for Perceiver.")
else:
if use_flash: print_once("Flash Attention requested but not available/enabled.")
self.use_flash = False # Disable if not available
# Flash attention config (simplified)
self.efficient_config = namedtuple("EfficientAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"])
# Set default configs, actual backend selection happens in F.scaled_dot_product_attention
self.cpu_config = self.efficient_config(True, True, True) # Default for CPU
self.cuda_config = self.efficient_config(True, True, True) # Default for CUDA
def get_mask(self, n, device):
if exists(self.mask) and self.mask.shape[-1] >= n and self.mask.device == device:
return self.mask[:n, :n]
mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
self.register_buffer("mask", mask, persistent=False)
return mask
def flash_attn(self, q, k, v, mask=None):
_, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
# k and v arrive from Attention.forward already shaped (b, h, n_kv, d_head),
# so no multi-query / grouped-KV expansion is needed here.
# Format the key-padding mask for F.scaled_dot_product_attention. The incoming
# mask uses the same convention as the manual path below: True = attend, False = pad.
# A boolean attn_mask in F.scaled_dot_product_attention also means True = attend,
# so the mask only needs broadcasting to (b, h, n_q, n_kv), not inversion.
flash_mask = None
if exists(mask):
if mask.ndim == 2: # (b, n_kv) key padding mask
flash_mask = rearrange(mask, "b j -> b 1 1 j")
flash_mask = flash_mask.expand(-1, heads, q_len, -1) # (b h n_q n_kv)
elif mask.ndim == 4 and mask.shape[1] == 1: # already (b 1 1 n_kv)
flash_mask = mask.expand(-1, heads, q_len, -1)
else:
# Assume the mask is already a correctly shaped boolean (b, h, n_q, n_kv)
flash_mask = mask
# F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal);
# an explicit attn_mask cannot be combined with is_causal=True, which is fine
# here because the Perceiver cross-attention uses causal=False.
out = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=flash_mask if exists(flash_mask) else None, # Pass boolean mask
dropout_p=self.dropout if self.training else 0.0,
is_causal=self.causal # Pass causal flag directly
)
return out
def forward(self, q, k, v, mask=None):
"""
einstein notation
b - batch
h - heads
n, i, j - sequence length (query, key/value)
d - feature dimension (d_head)
"""
n, device = q.shape[-2], q.device
scale = q.shape[-1] ** -0.5
if self.use_flash:
return self.flash_attn(q, k, v, mask=mask)
# Manual Attention Calculation
kv_einsum_eq = "b h j d" # Assuming k, v are always (b h n d)
# similarity
sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
# key padding mask
if exists(mask):
# mask shape (b, j) -> (b, 1, 1, j)
mask_value = -torch.finfo(sim.dtype).max
mask = rearrange(mask, "b j -> b 1 1 j")
sim = sim.masked_fill(~mask, mask_value) # Mask where mask is False
# causal mask (Not typically used in Perceiver cross-attention)
if self.causal:
causal_mask = self.get_mask(n, device) # (i, j)
sim = sim.masked_fill(causal_mask, mask_value)
# attention
attn = sim.softmax(dim=-1)
attn = self.attn_dropout(attn)
# aggregate values
out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
return out
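# Comment-only sketch of the mask convention used by Attend (tensor sizes are
# assumptions): True marks a valid key position, matching the boolean attn_mask
# convention of F.scaled_dot_product_attention (True = take part in attention).
#
#   attend = Attend(use_flash=True)
#   q = torch.randn(2, 8, 32, 64)                   # (b, heads, n_q, d_head)
#   k = torch.randn(2, 8, 100, 64)                  # (b, heads, n_kv, d_head)
#   v = torch.randn(2, 8, 100, 64)
#   mask = torch.ones(2, 100, dtype=torch.bool)     # True = attend to this key
#   mask[:, 80:] = False                            # last 20 positions are padding
#   out = attend(q, k, v, mask=mask)                # -> (2, 8, 32, 64)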
# Need Sequential, default, RMSNorm, GEGLU, FeedForward, Attention for PerceiverResampler
def Sequential(*mods): # Redefined earlier
return nn.Sequential(*filter(exists, mods))
class RMSNorm(nn.Module):
def __init__(self, dim, scale=True, dim_cond=None):
super().__init__()
self.cond = exists(dim_cond)
# Conditional LayerNorm not used in PerceiverResampler, simplify
# self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(dim)) if scale else None
def forward(self, x, cond=None): # cond is accepted for signature compatibility but unused
gamma = default(self.gamma, torch.tensor(1.0, device=x.device)) # Ensure gamma is tensor
# Note: F.normalize normalizes across the *last* dimension by default
normed_x = F.normalize(x, dim=-1)
return normed_x * self.scale * gamma
class CausalConv1d(nn.Conv1d): # Already defined earlier
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
kernel_size = self.kernel_size[0]
dilation = self.dilation[0]
stride = self.stride[0]
assert stride == 1
self.causal_padding = dilation * (kernel_size - 1)
def forward(self, x):
# Input x: (B, C, T)
causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0)
return super().forward(causal_padded_x)
class GEGLU(nn.Module): # Already defined earlier
def forward(self, x):
x, gate = x.chunk(2, dim=-1)
return F.gelu(gate) * x
def FeedForward(dim, mult=4, causal_conv=False): # Already defined earlier
dim_inner = int(dim * mult * 2 / 3)
conv = None
if causal_conv:
conv = nn.Sequential(
Rearrange("b n d -> b d n"),
CausalConv1d(dim_inner, dim_inner, 3),
Rearrange("b d n -> b n d"),
)
return Sequential(
nn.Linear(dim, dim_inner * 2, bias=False), # Bias False often used in transformers
GEGLU(),
conv,
nn.Linear(dim_inner, dim, bias=False) # Bias False
)
class Attention(nn.Module):
def __init__(
self,
dim,
*,
dim_context=None,
causal=False,
dim_head=64,
heads=8,
dropout=0.0,
use_flash=False,
cross_attn_include_queries=False,
):
super().__init__()
# self.scale = dim_head**-0.5 # scale is handled by Attend or flash attn
self.heads = heads
self.cross_attn_include_queries = cross_attn_include_queries
dim_inner = dim_head * heads
dim_context = default(dim_context, dim)
self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
self.to_q = nn.Linear(dim, dim_inner, bias=False)
# Combine K and V projection for efficiency
self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False)
self.to_out = nn.Linear(dim_inner, dim, bias=False)
def forward(self, x, context=None, mask=None):
h, has_context = self.heads, exists(context)
# x shape: (b, n_q, d)
# context shape: (b, n_kv, d_ctx)
context = default(context, x) # Use self if context not provided
if has_context and self.cross_attn_include_queries:
# Prepend queries to context for attention calculation
context = torch.cat((x, context), dim=-2) # (b, n_q + n_kv, d_ctx) - ensure dims match
# Project q, k, v
q = self.to_q(x)
k, v = self.to_kv(context).chunk(2, dim=-1)
# Reshape for multi-head attention
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
# Attend
out = self.attend(q, k, v, mask=mask) # mask should be (b, n_kv)
# Combine heads and project out
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
class PerceiverResampler(nn.Module):
def __init__(
self,
*,
dim,
depth=2,
dim_context=None,
num_latents=32,
dim_head=64,
heads=8,
ff_mult=4,
use_flash_attn=False,
):
super().__init__()
dim_context = default(dim_context, dim)
# Project context to query dimension if different
self.proj_context = (
nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity()
)
# Learnable latent queries
self.latents = nn.Parameter(torch.randn(num_latents, dim))
nn.init.normal_(self.latents, std=0.02) # Initialize latents
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
# Cross-Attention from latents (queries) to context (keys/values)
Attention(
dim=dim,
dim_context=dim, # Context is projected to dim
dim_head=dim_head,
heads=heads,
use_flash=use_flash_attn,
cross_attn_include_queries=False, # Standard Perceiver cross-attn
),
# Self-Attention within latents
# Optional: Add self-attention block here if needed
# Attention(
# dim=dim, dim_head=dim_head, heads=heads, use_flash=use_flash_attn
# ),
# FeedForward block
FeedForward(dim=dim, mult=ff_mult),
]
)
)
# Add LayerNorms (typically before attention and ff blocks)
# self.layers[-1].insert(0, RMSNorm(dim)) # Pre-Attention Norm
# self.layers[-1].insert(2, RMSNorm(dim)) # Pre-FF Norm
# Using Post-Norm structure as in original reference:
self.layers[-1].insert(1, RMSNorm(dim)) # After Attention
self.layers[-1].append(RMSNorm(dim)) # After FeedForward
# No separate final normalization is needed: each block already ends with a post-FeedForward RMSNorm.
# self.norm = RMSNorm(dim) # Uncomment to add an extra final norm over the latents
def forward(self, x, mask=None):
# x shape: (b, n_ctx, d_ctx)
batch = x.shape[0]
# Project context
x = self.proj_context(x) # (b, n_ctx, d)
# Repeat latents for batch
latents = repeat(self.latents, "n d -> b n d", b=batch) # (b, n_lat, d)
# Apply layers
# Original structure had norm inside loop, adapting: Attn -> Norm -> FF -> Norm
for attn, norm1, ff, norm2 in self.layers:
# Cross-Attention + Residual
latents_attn = attn(latents, x, mask=mask) # Query: latents, Context: x
latents = norm1(latents_attn + latents)
# FeedForward + Residual
latents_ff = ff(latents)
latents = norm2(latents_ff + latents)
# return self.norm(latents) # Only if a final norm is defined in __init__
return latents # Each block already ends with a post-FeedForward RMSNorm
# ===============================================================
# End: Content from sparktts/modules/speaker/perceiver_encoder.py
# ===============================================================
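# Comment-only usage sketch of PerceiverResampler (dims are assumptions): a
# variable-length context sequence is compressed to a fixed number of latents.
#
#   resampler = PerceiverResampler(dim=128, dim_context=1536, num_latents=32)
#   feats = torch.randn(2, 200, 1536)   # (B, T_ctx, D_ctx), e.g. ECAPA latents transposed
#   latents = resampler(feats)          # -> (2, 32, 128) regardless of T_ctx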
# ===============================================================
# Start: Content from sparktts/modules/speaker/speaker_encoder.py
# ===============================================================
class SpeakerEncoder(nn.Module):
"""
Speaker Encoder using ECAPA-TDNN, Perceiver Resampler, and Residual FSQ.
Args:
input_dim (int): acoustic feature dimension (e.g., mel bins)
out_dim (int): output dimension of the final d-vector
latent_dim (int): latent dimension for perceiver and quantization
token_num (int): number of latent tokens from perceiver
fsq_levels (List[int]): levels for finite scalar quantization
fsq_num_quantizers (int): number of residual quantizers in FSQ
ecapa_embed_dim (int): embedding dimension from ECAPA-TDNN (before projection)
"""
def __init__(
self,
input_dim: int = 80, # Default mel bins
out_dim: int = 1024, # Target d-vector dim from config
latent_dim: int = 128, # Latent dim for perceiver/quantizer
token_num: int = 32, # Number of speaker tokens
fsq_levels: List[int] = [4, 4, 4, 4, 4, 4],
fsq_num_quantizers: int = 1,
# Add ECAPA config params if needed, or use defaults
ecapa_channels: int = 512,
ecapa_embed_dim: int = 192, # Default ECAPA embed dim
):
super(SpeakerEncoder, self).__init__()
# ECAPA-TDNN for initial feature extraction and x-vector (optional)
# Using the GLOB variant as in the original __main__ test
self.speaker_encoder_base = ECAPA_TDNN_GLOB_c512(
feat_dim=input_dim,
embed_dim=ecapa_embed_dim # Use specific ECAPA embed dim
)
# Dimension of features extracted by ECAPA (latent before pooling)
ecapa_feature_dim = ecapa_channels * 3 # From concatenation in ECAPA
# Perceiver Resampler to get fixed-length sequence from variable-length ECAPA features
self.perceiver_sampler = PerceiverResampler(
dim=latent_dim, # Output dim of perceiver latents
dim_context=ecapa_feature_dim, # Input dim from ECAPA features
num_latents=token_num,
depth=2, # Default depth, adjust if needed
dim_head=64, heads=8, ff_mult=4, # Default attention/ff params
use_flash_attn=True # Enable flash attention if available
)
# Residual Finite Scalar Quantizer
self.quantizer = ResidualFSQ(
levels=fsq_levels,
num_quantizers=fsq_num_quantizers,
dim=latent_dim, # Quantizer operates on perceiver output dim
channel_first=False, # Perceiver output is (B, T, D), so channel_first=False
quantize_dropout=False, # No dropout specified in config
)
# Final projection from flattened quantized tokens to the target output dimension
self.project = nn.Linear(latent_dim * token_num, out_dim)
def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor:
"""Reconstruct quantized vectors from indices."""
# indices have shape (B, T_token, Q), matching the forward pass / tokenize output.
# For an ungrouped ResidualFSQ, get_output_from_indices accepts (B, ..., Q) directly;
# grouped quantizers would instead expect a tuple of per-group index chunks.
zq = self.quantizer.get_output_from_indices(indices)
# Output zq shape should be (B, T_token, latent_dim)
return zq
def get_indices(self, mels: torch.Tensor) -> torch.Tensor:
"""Get FSQ indices directly from mel spectrograms."""
# mels: (B, T_mel, D_mel)
_, features = self.speaker_encoder_base(mels, return_latent=True) # features: (B, ecapa_feat_dim, T_feat)
x = self.perceiver_sampler(features.transpose(1, 2)) # Input: (B, T_feat, ecapa_feat_dim), Output: (B, token_num, latent_dim)
_, indices = self.quantizer(x) # Input: (B, token_num, latent_dim), indices: (B, token_num, Q)
return indices
def forward(self, mels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
mels: (B, T_mel, D_mel) - Mel spectrogram input
Return:
x_vector: (B, ecapa_embed_dim) - Global speaker embedding from ECAPA
d_vector: (B, out_dim) - Speaker embedding derived from quantized tokens
"""
# Get base speaker embedding (x-vector) and intermediate features from ECAPA
x_vector, features = self.speaker_encoder_base(mels, return_latent=True)
# features shape: (B, ecapa_feat_dim, T_feat)
# Resample features using Perceiver
# Perceiver expects (B, T, D), so transpose features
perceiver_latents = self.perceiver_sampler(features.transpose(1, 2))
# perceiver_latents shape: (B, token_num, latent_dim)
# Quantize the perceiver latents
# Quantizer expects (B, T, D) if channel_first=False
zq, indices = self.quantizer(perceiver_latents)
# zq shape: (B, token_num, latent_dim), indices shape: (B, token_num, Q)
# Flatten quantized tokens and project to final d-vector dimension
zq_flat = rearrange(zq, 'b t d -> b (t d)') # (B, token_num * latent_dim)
d_vector = self.project(zq_flat) # (B, out_dim)
return x_vector, d_vector
def tokenize(self, mels: torch.Tensor) -> torch.Tensor:
"""Tokenize the input mel spectrogram to get FSQ indices."""
# Same logic as get_indices
_, features = self.speaker_encoder_base(mels, return_latent=True) # features: (B, ecapa_feat_dim, T_feat)
x = self.perceiver_sampler(features.transpose(1, 2)) # (B, token_num, latent_dim)
_, indices = self.quantizer(x) # indices: (B, token_num, Q)
return indices
def detokenize(self, indices: torch.Tensor) -> torch.Tensor:
"""Detokenize FSQ indices to get the final d-vector."""
# indices shape: (B, token_num, Q)
# Reconstruct quantized vectors from indices
zq = self.get_codes_from_indices(indices) # (B, token_num, latent_dim)
# Flatten and project
zq_flat = rearrange(zq, 'b t d -> b (t d)')
d_vector = self.project(zq_flat)
return d_vector
# ===============================================================
# End: Content from sparktts/modules/speaker/speaker_encoder.py
# ===============================================================
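# Comment-only usage sketch of SpeakerEncoder (shapes follow the docstrings above;
# the mel length is an arbitrary assumption):
#
#   spk = SpeakerEncoder(input_dim=80, out_dim=1024, latent_dim=128, token_num=32)
#   mels = torch.randn(2, 300, 80)          # (B, T_mel, D_mel)
#   x_vec, d_vec = spk(mels)                # (2, 192), (2, 1024)
#   tokens = spk.tokenize(mels)             # (2, 32, num_quantizers)
#   d_vec_rec = spk.detokenize(tokens)      # (2, 1024)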
# ===============================================================
# Start: Content from sparktts/modules/vq/factorized_vector_quantize.py
# ===============================================================
# Helper function from layers.py (already defined)
# def WNConv1d(*args, **kwargs):
# return weight_norm(nn.Conv1d(*args, **kwargs))
def ema_inplace(moving_avg, new, decay):
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
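# Comment-only example of ema_inplace (illustrative numbers): running statistics
# are updated in place without tracking gradients.
#
#   counts = torch.zeros(4)
#   ema_inplace(counts, torch.tensor([1.0, 0.0, 2.0, 0.0]), decay=0.99)
#   # counts is now 0.99 * 0 + 0.01 * [1, 0, 2, 0] = [0.01, 0.00, 0.02, 0.00]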
class FactorizedVectorQuantize(nn.Module):
def __init__(
self,
input_dim: int,
codebook_size: int,
codebook_dim: int,
commitment: float,
codebook_loss_weight: float = 1.0,
decay: float = 0.99,
threshold_ema_dead_code: float = 2.0, # EMA cluster-size threshold below which a code counts as dead
momentum: float = 0.99, # Accepted for config compatibility; unused in this implementation
use_l2_normlize: bool = True, # Added from config
**kwargs,
):
super().__init__()
self.input_dim = input_dim
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.commitment = commitment
self.codebook_loss_weight = codebook_loss_weight
self.decay = decay
self.threshold_ema_dead_code = threshold_ema_dead_code
# self.momentum = momentum # Store if needed later
self.use_l2_normlize = use_l2_normlize
if input_dim != self.codebook_dim:
self.in_project = WNConv1d(input_dim, self.codebook_dim, kernel_size=1)
self.out_project = WNConv1d(self.codebook_dim, input_dim, kernel_size=1)
else:
self.in_project = nn.Identity()
self.out_project = nn.Identity()
# Codebook embedding layer
self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
# Initialize codebook? Often random init is fine.
# Buffers for EMA updates (cluster size and maybe embeddings)
self.register_buffer("cluster_size", torch.zeros(self.codebook_size))
# EMA average embeddings (optional, can use self.codebook.weight directly for loss)
# self.register_buffer("ema_embed", self.codebook.weight.clone())
def forward(self, z: torch.Tensor) -> Dict[str, Any]:
"""Quantizes the input tensor using a fixed codebook and returns
the corresponding codebook vectors and losses.
Parameters
----------
z : Tensor[B x D_in x T]
Returns
-------
Dict containing:
z_q (Tensor[B x D_in x T]): Quantized continuous representation (passed through out_project)
indices (Tensor[B x T]): Codebook indices
vq_loss (Tensor[1]): Combined VQ loss (codebook + commitment)
perplexity (Tensor[1]): Codebook perplexity metric
active_num (Tensor[1]): Number of active codebook entries
"""
# z: (B, D_in, T)
B, _, T = z.shape
# Project input to codebook dimension if necessary
z_e = self.in_project(z) # (B, D_code, T)
# Find nearest neighbors and get quantized vectors + indices
z_q, indices, dists = self.decode_latents(z_e) # z_q: (B, D_code, T), indices: (B, T)
# Calculate statistics for perplexity and active codes
with torch.no_grad(): # Stats should not contribute to gradient
embed_onehot = F.one_hot(indices, self.codebook_size).type(z_e.dtype) # (B, T, C)
# Flatten batch and time dims for stats
embed_onehot_flat = rearrange(embed_onehot, 'b t c -> (b t) c')
avg_probs = torch.mean(embed_onehot_flat, dim=0) # (C,)
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
# EMA update for cluster size (only in training)
active_num_tensor = (embed_onehot_flat.sum(0) > 0).sum() # Before EMA
if self.training:
# Perform EMA update in place
ema_inplace(self.cluster_size, embed_onehot_flat.sum(0), self.decay)
# Calculate active codes based on EMA threshold
active_num_tensor = (self.cluster_size > self.threshold_ema_dead_code).sum()
# Calculate losses (only in training)
commit_loss = torch.tensor(0.0, device=z.device)
codebook_loss = torch.tensor(0.0, device=z.device)
vq_loss = torch.tensor(0.0, device=z.device)
if self.training:
# Commitment loss (encourage encoder output E(x) to be close to codebook z_q)
# Use z_e (projected encoder output) and z_q.detach()
commit_loss = F.mse_loss(z_e, z_q.detach()) * self.commitment
# Codebook loss (encourage codebook entries z_q to be close to encoder output E(x))
# Use z_q and z_e.detach()
codebook_loss = F.mse_loss(z_q, z_e.detach()) * self.codebook_loss_weight
vq_loss = commit_loss + codebook_loss
# Straight-through estimator: copy gradient from z_q to z_e
z_q_st = z_e + (z_q - z_e).detach()
# Project quantized vectors back to input dimension if necessary
z_q_out = self.out_project(z_q_st) # (B, D_in, T)
return {
"z_q": z_q_out,
"indices": indices,
# "dists": dists, # Dists might be large, exclude unless needed
"vq_loss": vq_loss,
"perplexity": perplexity,
"active_num": active_num_tensor.float(),
}
def embed_code(self, embed_id):
"""Retrieve codebook vectors for given indices."""
return F.embedding(embed_id, self.codebook.weight)
def decode_code(self, embed_id):
"""Retrieve codebook vectors and transpose to (B, D, T) format."""
# embed_id: (B, T)
# Embedding: (B, T, D_code)
# Transpose: (B, D_code, T)
return self.embed_code(embed_id).transpose(1, 2)
def decode_latents(self, latents):
"""Find nearest codebook entries for latent vectors."""
# latents: (B, D_code, T)
B, D_code, T = latents.shape
encodings = rearrange(latents, "b d t -> (b t) d") # ((B*T), D_code)
codebook = self.codebook.weight # (C, D_code)
# Normalize if required
if self.use_l2_normlize:
encodings = F.normalize(encodings, p=2, dim=-1)
codebook = F.normalize(codebook, p=2, dim=-1)
# Compute distances (squared Euclidean or Cosine depending on normalization)
# dist = torch.cdist(encodings, codebook, p=2)**2 # Squared Euclidean
# Faster calculation using matrix multiplication if normalized:
# dist = 2 - 2 * (encodings @ codebook.t())
# Or full squared Euclidean:
dist = (
encodings.pow(2).sum(1, keepdim=True) # (B*T, 1)
- 2 * (encodings @ codebook.t()) # (B*T, C)
+ codebook.pow(2).sum(1, keepdim=True).t() # (1, C)
) # Result shape: (B*T, C)
# Find nearest neighbors
indices = torch.argmin(dist, dim=-1) # (B*T)
indices = rearrange(indices, "(b t) -> b t", b=B) # (B, T)
# Get the quantized vectors
z_q = self.decode_code(indices) # (B, D_code, T)
return z_q, indices, dist # Return dist if needed, e.g., for debugging
# --- Methods for inference/tokenization ---
def tokenize(self, z: torch.Tensor) -> torch.Tensor:
"""Tokenize the input tensor without loss calculation."""
# z: (B, D_in, T)
z_e = self.in_project(z) # (B, D_code, T)
_, indices, _ = self.decode_latents(z_e) # indices: (B, T)
return indices
def detokenize(self, indices: torch.Tensor) -> torch.Tensor:
"""Detokenize indices to quantized vectors in input dimension."""
# indices: (B, T)
z_q_code_dim = self.decode_code(indices) # (B, D_code, T)
z_q_out = self.out_project(z_q_code_dim) # (B, D_in, T)
return z_q_out
# ===============================================================
# End: Content from sparktts/modules/vq/factorized_vector_quantize.py
# ===============================================================
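# Comment-only sketch of the FactorizedVectorQuantize tokenize/detokenize round
# trip (hyperparameters here are assumptions, not the shipped BiCodec config):
#
#   vq = FactorizedVectorQuantize(input_dim=1024, codebook_size=8192,
#                                 codebook_dim=8, commitment=0.25)
#   z = torch.randn(2, 1024, 50)        # (B, D_in, T)
#   indices = vq.tokenize(z)            # (2, 50) integer codebook indices
#   z_q = vq.detokenize(indices)        # (2, 1024, 50) quantized features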
# --- BiCodec Model Definition (Adapted from sparktts/models/bicodec.py) ---
class BiCodec(nn.Module):
"""
BiCodec model for speech synthesis, incorporating a speaker encoder, feature encoder/decoder,
quantizer, and wave generator.
"""
def __init__(
self,
mel_params: Dict[str, Any],
encoder: nn.Module,
decoder: nn.Module,
quantizer: nn.Module,
speaker_encoder: nn.Module,
prenet: nn.Module,
postnet: nn.Module,
**kwargs
) -> None:
"""
Initializes the BiCodec model with the required components.
Args:
mel_params (dict): Parameters for the mel-spectrogram transformer.
encoder (nn.Module): Encoder module.
decoder (nn.Module): Decoder module.
quantizer (nn.Module): Quantizer module.
speaker_encoder (nn.Module): Speaker encoder module.
prenet (nn.Module): Prenet network.
postnet (nn.Module): Postnet network.
"""
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.quantizer = quantizer
self.speaker_encoder = speaker_encoder
self.prenet = prenet
self.postnet = postnet
self._init_mel_transformer(mel_params)
@classmethod
def load_from_config_and_checkpoint(cls, model_dir: Path, config_dict: Dict[str, Any], **kwargs) -> "BiCodec":
"""Loads the model from a config dictionary and checkpoint file."""
ckpt_path = model_dir / 'model.safetensors'
if not ckpt_path.is_file():
raise FileNotFoundError(f"BiCodec checkpoint not found at {ckpt_path}")
audio_tokenizer_config = config_dict # Assuming config_dict holds the relevant sub-config
# Instantiate components using classes from _modeling_bicodec_components
mel_params = audio_tokenizer_config.get("mel_params", {})
encoder_cfg = audio_tokenizer_config.get("encoder", {})
quantizer_cfg = audio_tokenizer_config.get("quantizer", {})
prenet_cfg = audio_tokenizer_config.get("prenet", {})
postnet_cfg = audio_tokenizer_config.get("postnet", {})
decoder_cfg = audio_tokenizer_config.get("decoder", {}) # This corresponds to WaveGenerator
speaker_encoder_cfg = audio_tokenizer_config.get("speaker_encoder", {})
# --- Input Validation ---
required_keys = {
"encoder": ["input_channels", "vocos_dim", "vocos_intermediate_dim", "vocos_num_layers", "out_channels"],
"quantizer": ["input_dim", "codebook_size", "codebook_dim", "commitment"],
"prenet": ["input_channels", "vocos_dim", "vocos_intermediate_dim", "vocos_num_layers", "out_channels"],
"postnet": ["input_channels", "vocos_dim", "vocos_intermediate_dim", "vocos_num_layers", "out_channels"],
"decoder": ["input_channel", "channels", "rates", "kernel_sizes"], # WaveGenerator keys
"speaker_encoder": ["input_dim", "out_dim", "latent_dim", "token_num"],
"mel_params": ["sample_rate", "n_fft", "win_length", "hop_length", "num_mels"]
}
for comp, keys in required_keys.items():
cfg = audio_tokenizer_config.get(comp, {})
if not cfg: logging.get_logger(__name__).warning(f"BiCodec config missing section: '{comp}'")
for key in keys:
if key not in cfg:
logging.get_logger(__name__).warning(f"BiCodec config missing key '{key}' in section '{comp}'")
# --- End Validation ---
# Instantiate modules
encoder = Encoder(**encoder_cfg) if encoder_cfg else None
quantizer = FactorizedVectorQuantize(**quantizer_cfg) if quantizer_cfg else None
prenet = Decoder(**prenet_cfg) if prenet_cfg else None
postnet = Decoder(**postnet_cfg) if postnet_cfg else None
decoder = WaveGenerator(**decoder_cfg) if decoder_cfg else None # WaveGenerator instance
speaker_encoder = SpeakerEncoder(**speaker_encoder_cfg) if speaker_encoder_cfg else None
# Check if all components were successfully created
if not all([encoder, quantizer, prenet, postnet, decoder, speaker_encoder, mel_params]):
raise ValueError("Failed to initialize one or more BiCodec components due to missing configuration.")
# Create the BiCodec instance
model = cls(
mel_params=mel_params,
encoder=encoder,
decoder=decoder, # Pass WaveGenerator instance as decoder
quantizer=quantizer,
speaker_encoder=speaker_encoder,
prenet=prenet,
postnet=postnet,
)
# Load state dict
try:
state_dict = load_file(ckpt_path, device="cpu") # Load to CPU first
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if missing_keys:
print(f"BiCodec missing keys: {missing_keys}")
if unexpected_keys:
print(f"BiCodec unexpected keys: {unexpected_keys}")
except Exception as e:
raise IOError(f"Error loading BiCodec state dict from {ckpt_path}: {e}")
model.eval()
# model.remove_weight_norm() # Optionally strip weight norm via BiCodec.remove_weight_norm before inference
return model
def _init_mel_transformer(self, config: Dict[str, Any]):
# Ensure required keys exist with defaults
sr = config.get("sample_rate", 16000)
n_fft = config.get("n_fft", 1024)
win_length = config.get("win_length", n_fft)
hop_length = config.get("hop_length", n_fft // 4)
fmin = config.get("mel_fmin", 0)
fmax = config.get("mel_fmax", None)
n_mels = config.get("num_mels", 80)
power = config.get("power", 2.0) # Typically 2.0 for power spectrogram
norm = config.get("norm", "slaney")
mel_scale = config.get("mel_scale", "htk") # htk or slaney
self.mel_transformer = TT.MelSpectrogram(
sample_rate=sr,
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
f_min=fmin,
f_max=fmax,
n_mels=n_mels,
power=power,
norm=norm,
mel_scale=mel_scale,
).eval() # Set to eval mode
def remove_weight_norm(self):
"""Removes weight normalization from components that support it."""
def _remove_wn(m):
if hasattr(m, 'remove_weight_norm'):
m.remove_weight_norm()
elif isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
try:
remove_weight_norm(m)
except ValueError:
pass # Module might not have weight norm applied
self.apply(_remove_wn)
@torch.no_grad()
def tokenize(self, feat: torch.Tensor, ref_wav: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
""" Tokenizes input features and reference wav into semantic and global tokens. """
# Ensure models are on the correct device
device = feat.device
self.mel_transformer.to(device)
self.encoder.to(device)
self.quantizer.to(device)
self.speaker_encoder.to(device)
# feat: (B, D_feat, T_feat), ref_wav: (B, T_wav)
mel = self.mel_transformer(ref_wav) # (B, D_mel, T_mel)
# Encode features to get latents for semantic tokens
z = self.encoder(feat) # (B, D_latent, T_latent) - Assuming Encoder output matches quantizer input dim
# Quantize latents to get semantic tokens (indices)
semantic_tokens = self.quantizer.tokenize(z) # (B, T_latent)
# Encode mel spectrogram to get global tokens (indices)
# SpeakerEncoder.tokenize expects (B, T_mel, D_mel)
global_tokens = self.speaker_encoder.tokenize(mel.transpose(1, 2)) # (B, token_num, Q)
# Note: Original BiCodecTokenizer returned (global_tokens, semantic_tokens)
# Let's stick to that order for consistency with original SparkTTS usage.
return global_tokens, semantic_tokens
@torch.no_grad()
def detokenize(self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor) -> torch.Tensor:
""" Detokenizes semantic and global tokens into a waveform. """
# Ensure models are on the correct device
device = semantic_tokens.device # Assume tokens are on target device
self.quantizer.to(device)
self.speaker_encoder.to(device)
self.prenet.to(device)
self.decoder.to(device) # WaveGenerator
# semantic_tokens: (B, T_latent) from FactorizedVectorQuantize.tokenize
# global_tokens: (B, token_num, Q) from SpeakerEncoder.tokenize
# Reconstruct quantized vectors from semantic tokens
z_q = self.quantizer.detokenize(semantic_tokens) # (B, D_latent, T_latent)
# Reconstruct d-vector (condition) from global tokens
# SpeakerEncoder.detokenize expects (B, T_token, Q)
d_vector = self.speaker_encoder.detokenize(global_tokens) # (B, D_dvector)
# Apply prenet conditioned on d-vector
# Prenet (Decoder class) expects input (B, D_latent, T_latent) and condition (B, D_dvector)
x = self.prenet(z_q, d_vector) # (B, D_prenet_out, T_latent) - Assuming prenet maintains time dim
# Add the d-vector condition (broadcast over the time axis) before wave generation
# Ensure d_vector has correct shape for broadcasting
if d_vector.ndim == 2:
d_vector_unsqueezed = d_vector.unsqueeze(-1) # (B, D_dvector, 1)
else: # Should not happen if speaker_encoder outputs (B, D)
d_vector_unsqueezed = d_vector
# Ensure dimensions match for addition
if x.shape[1] == d_vector_unsqueezed.shape[1]:
# Broadcast d_vector across time dimension T_latent
x = x + d_vector_unsqueezed
else:
# Maybe project d_vector or x? Log a warning or adapt based on expected dims.
logging.get_logger(__name__).warning(f"Prenet output dim {x.shape[1]} != d-vector dim {d_vector_unsqueezed.shape[1]}. Skipping residual connection.")
# Generate waveform using the decoder (WaveGenerator)
# WaveGenerator expects (B, D_input, T_input)
wav_recon = self.decoder(x) # (B, 1, T_wav)
return wav_recon
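# Comment-only sketch of the BiCodec tokenize/detokenize flow (the config dict,
# checkpoint directory and tensor sizes below are illustrative assumptions):
#
#   codec = BiCodec.load_from_config_and_checkpoint(Path("BiCodec"), cfg["audio_tokenizer"])
#   feat = torch.randn(1, 1024, 120)    # wav2vec2 features (B, D_feat, T_feat)
#   ref_wav = torch.randn(1, 96000)     # reference clip (B, T_wav)
#   global_tokens, semantic_tokens = codec.tokenize(feat, ref_wav)
#   wav = codec.detokenize(global_tokens, semantic_tokens)   # (B, 1, T_wav_out)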
# --- Main SparkTTS Model ---
from .configuration_spark_tts import SparkTTSConfig
# from ._utils import load_audio # Use utils from _utils.py
logger = logging.get_logger(__name__)
class SparkTTSModel(PreTrainedModel, GenerationMixin):
"""
SparkTTS model integrating LLM, BiCodec, and Wav2Vec2 for text-to-speech.
"""
config_class = SparkTTSConfig
base_model_prefix = "spark_tts"
_supports_load_fast = False
def __init__(self, config: SparkTTSConfig, llm=None, wav2vec2_model=None, wav2vec2_processor=None, bicodec=None):
super().__init__(config)
self.config = config
self.llm = llm
self.wav2vec2_model = wav2vec2_model
self.wav2vec2_processor = wav2vec2_processor
self.bicodec = bicodec
# Wav2Vec2 specific config adjustment (needs to happen after loading)
if self.wav2vec2_model and hasattr(self.wav2vec2_model.config, 'output_hidden_states'):
self.wav2vec2_model.config.output_hidden_states = True
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
*model_args,
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
ignore_mismatched_sizes: bool = False,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
use_safetensors: bool = None,
**kwargs,
):
# 1. Load Config
if config is None:
config, model_kwargs = cls.config_class.from_pretrained(
pretrained_model_name_or_path,
*model_args,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
return_unused_kwargs=True,
**kwargs,
)
else:
model_kwargs = kwargs
# Pop device map info - will handle placement later
device_map = model_kwargs.pop("device_map", None)
torch_dtype = model_kwargs.pop("torch_dtype", "auto") # Use config's or auto
# Check for trust_remote_code - needed for config loading if custom code involved there too
trust_remote_code = model_kwargs.pop("trust_remote_code", False) # Important
# Determine actual model directory (could be cache path)
if pretrained_model_name_or_path is not None:
resolved_model_path = Path(pretrained_model_name_or_path)
if not resolved_model_path.is_dir():
# Attempt to download and resolve cache path if it's an ID
# This requires internet connection if not cached
try:
resolved_model_path = Path(cached_file(
pretrained_model_name_or_path,
filename=cls.config_class.config_files[0], # e.g., "config.json"
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)).parent
except Exception as e:
logger.warning(f"Could not resolve cache path for {pretrained_model_name_or_path}: {e}. Assuming it's a local path.")
resolved_model_path = Path(pretrained_model_name_or_path) # Fallback
if not resolved_model_path.is_dir():
raise EnvironmentError(f"Cannot find model directory at {resolved_model_path}")
else:
raise ValueError("pretrained_model_name_or_path must be provided.")
# Helper function to resolve paths relative to the main model directory
def _resolve_path(sub_path):
p = Path(sub_path)
if p.is_absolute():
return str(p)
else:
# Resolve relative to the potentially cached main model path
return str(resolved_model_path / p)
# --- Load LLM ---
llm_path = _resolve_path(config.llm_model_name_or_path)
logger.info(f"Loading LLM from resolved path: {llm_path}")
try:
llm = AutoModelForCausalLM.from_pretrained(
llm_path,
torch_dtype=torch_dtype if torch_dtype != "auto" else config.torch_dtype, # Prioritize explicit dtype
trust_remote_code=trust_remote_code, # Pass down trust_remote_code
**model_kwargs # Pass remaining kwargs
)
except Exception as e:
raise OSError(f"Failed to load LLM from {llm_path}: {e}")
# --- Load Wav2Vec2 ---
w2v_path = _resolve_path(config.wav2vec2_model_name_or_path)
logger.info(f"Loading Wav2Vec2 from resolved path: {w2v_path}")
try:
wav2vec2_processor = Wav2Vec2FeatureExtractor.from_pretrained(w2v_path, trust_remote_code=trust_remote_code)
wav2vec2_model = Wav2Vec2Model.from_pretrained(w2v_path, trust_remote_code=trust_remote_code)
except Exception as e:
raise OSError(f"Failed to load Wav2Vec2 components from {w2v_path}: {e}")
# --- Load BiCodec ---
bicodec_path = _resolve_path(config.bicodec_model_name_or_path)
logger.info(f"Loading BiCodec from resolved path: {bicodec_path}")
# print(f"Loading BiCodec from resolved path: {bicodec_path}, {config}")
if not config.bicodec_config or "audio_tokenizer" not in config.bicodec_config:
raise ValueError("BiCodec configuration ('bicodec_config' with 'audio_tokenizer' key) not found in SparkTTSConfig.")
try:
# Assuming BiCodec class is defined above in this file
bicodec = BiCodec.load_from_config_and_checkpoint(
model_dir=Path(bicodec_path),
config_dict=config.bicodec_config["audio_tokenizer"]
)
except Exception as e:
raise OSError(f"Failed to load BiCodec from {bicodec_path}: {e}")
# Instantiate the main model wrapper, passing the loaded components
model = cls(config, llm=llm, wav2vec2_model=wav2vec2_model, wav2vec2_processor=wav2vec2_processor, bicodec=bicodec)
# --- Handle device placement ---
# Note: device_map is complex; simple .to(device) is easier if not using accelerate
# Determine target device
if torch.cuda.is_available():
current_device = torch.cuda.current_device()
device = torch.device(f"cuda:{current_device}")
else:
device = torch.device("cpu")
logger.info(f"Placing SparkTTSModel and components on device: {device}")
model.to(device) # This should move all registered nn.Module attributes
return model
# --- Embedding getters/setters (delegate to LLM if loaded) ---
def get_input_embeddings(self):
if self.llm:
return self.llm.get_input_embeddings()
return None # Or raise error
def set_input_embeddings(self, value):
if self.llm:
self.llm.set_input_embeddings(value)
else:
logger.warning("LLM not loaded, cannot set input embeddings.")
def get_output_embeddings(self):
if self.llm:
# For causal LM, output embeddings are usually tied to lm_head
return self.llm.get_output_embeddings()
return None # Or raise error
def set_output_embeddings(self, new_embeddings):
if self.llm and hasattr(self.llm, 'set_output_embeddings'):
self.llm.set_output_embeddings(new_embeddings)
else:
logger.warning("LLM not loaded or does not support set_output_embeddings.")
# --- End Embedding methods ---
# post_init is less critical now as loading happens in from_pretrained,
# but can be used for final checks or setup.
def post_init(self):
# Ensure wav2vec2 config has output_hidden_states=True
if self.wav2vec2_model and hasattr(self.wav2vec2_model.config, 'output_hidden_states'):
if not self.wav2vec2_model.config.output_hidden_states:
self.wav2vec2_model.config.output_hidden_states = True
logger.info("Set wav2vec2_model.config.output_hidden_states=True")
@property
def device(self) -> torch.device:
""" Override device property to report the LLM's device as representative """
if self.llm:
return self.llm.device
else:
# Fallback if the LLM is not loaded yet; the pipeline may query the device
# before full initialization, so fall back to the first parameter's device.
try:
return next(self.parameters()).device
except StopIteration:
# If no parameters, default to CPU
return torch.device("cpu")
@torch.no_grad()
def _extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor:
"""Extract wav2vec2 features. Input wavs: (B, T_wav)"""
if not self.wav2vec2_model or not self.wav2vec2_processor:
raise RuntimeError("Wav2Vec2 components not loaded.")
# Use component's device
target_device = self.wav2vec2_model.device
wavs_on_device = wavs.to(target_device) # Expected shape [B, T_wav] e.g., [1, 61120]
# Process audio using the Wav2Vec2FeatureExtractor
processor_output = self.wav2vec2_processor(
wavs_on_device,
sampling_rate=self.config.sample_rate,
return_tensors="pt",
padding=True, # Ensure padding is handled correctly
)
inputs = processor_output.input_values # Should be shape [B, T_processed]
# The feature extractor can return extra singleton dimensions depending on how the
# batched waveform is passed in; squeeze them so Wav2Vec2Model receives a 2D
# (batch_size, sequence_length) tensor.
logger.debug(f"Shape returned by wav2vec2 processor: {inputs.shape}")
if inputs.ndim == 4 and inputs.shape[1] == 1 and inputs.shape[2] == 1:
inputs = inputs.squeeze(1).squeeze(1) # Remove the two singleton middle dimensions
elif inputs.ndim == 3 and inputs.shape[1] == 1:
inputs = inputs.squeeze(1) # Remove the singleton channel dimension
if inputs.ndim != 2:
raise ValueError(f"Unexpected shape after processing/reshaping: {inputs.shape}. Expected 2D input for Wav2Vec2Model.")
logger.debug(f"Input shape passed to Wav2Vec2Model: {inputs.shape}")
inputs = inputs.to(target_device)
# Ensure output_hidden_states=True during call if not set reliably in config
outputs = self.wav2vec2_model(inputs, output_hidden_states=True)
if outputs.hidden_states is None:
raise ValueError("Wav2Vec2 model did not return hidden states. Ensure config.output_hidden_states=True.")
# Mix specific layers
num_layers = len(outputs.hidden_states)
indices_to_mix = [11, 14, 16]
valid_indices = [i for i in indices_to_mix if i < num_layers]
if len(valid_indices) != len(indices_to_mix):
logger.warning(f"Requested Wav2Vec2 hidden state indices {indices_to_mix} out of range (0-{num_layers-1}). Using available valid indices: {valid_indices}.")
if not valid_indices: # If no valid indices, use last hidden state
logger.warning("No valid hidden state indices for mixing. Using last hidden state.")
feats_mix = outputs.last_hidden_state
else:
# Mix available valid indices
feats_mix = torch.stack([outputs.hidden_states[i] for i in valid_indices]).mean(dim=0)
else:
# Original mixing logic
feats_mix = (outputs.hidden_states[11] + outputs.hidden_states[14] + outputs.hidden_states[16]) / 3
# Output shape: (B, T_feat, D_feat) - Transpose needed for BiCodec Encoder
return feats_mix.transpose(1, 2) # (B, D_feat, T_feat)
def _get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
"""Get reference audio clip for speaker embedding."""
ref_samples = int(self.config.sample_rate * self.config.ref_segment_duration)
latent_hop_length = self.config.latent_hop_length
# Ensure length is multiple of hop_length for potential downstream processing
ref_segment_length = max(latent_hop_length, (ref_samples // latent_hop_length) * latent_hop_length) # Ensure at least one hop
wav_length = len(wav)
if wav_length == 0: # Handle empty input
return np.zeros(ref_segment_length, dtype=np.float32)
if ref_segment_length > wav_length:
num_repeats = (ref_segment_length // wav_length) + 1
wav = np.tile(wav, num_repeats)
return wav[:ref_segment_length].astype(np.float32) # Ensure float32
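# Worked example of the reference-clip length computation above (numbers are
# assumptions: sample_rate=16000, ref_segment_duration=6.0, latent_hop_length=320):
#   ref_samples = int(16000 * 6.0) = 96000
#   ref_segment_length = (96000 // 320) * 320 = 96000   (already a hop multiple)
# Shorter inputs are tiled up to 96000 samples; longer inputs are cropped.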
@torch.no_grad()
def _tokenize_audio(self, audio_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
"""Load audio, extract features, and tokenize using BiCodec."""
wav_np = load_audio(
audio_path,
sampling_rate=self.config.sample_rate,
volume_normalize=self.config.volume_normalize,
)
wav_ref_np = self._get_ref_clip(wav_np)
# Convert to tensors, add batch dim, move to device
wav = torch.from_numpy(wav_np).unsqueeze(0).float().to(self.device)
ref_wav = torch.from_numpy(wav_ref_np).unsqueeze(0).float().to(self.device)
# Extract Wav2Vec2 features -> (B, D_feat, T_feat)
feat = self._extract_wav2vec2_features(wav)
# Tokenize using BiCodec -> returns (global_tokens, semantic_tokens)
# BiCodec.tokenize expects feat: (B, D_feat, T_feat), ref_wav: (B, T_wav)
global_tokens, semantic_tokens = self.bicodec.tokenize(feat, ref_wav)
# global_tokens: (B, T_token, Q), semantic_tokens: (B, T_latent)
return global_tokens, semantic_tokens
@torch.no_grad()
def _detokenize_audio(self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor) -> np.ndarray:
"""Detokenize using BiCodec to get waveform."""
global_tokens = global_tokens.to(self.device)
semantic_tokens = semantic_tokens.to(self.device)
self.bicodec.to(self.device) # Ensure BiCodec is on device
# BiCodec.detokenize expects global_tokens: (B, T_token, Q), semantic_tokens: (B, T_latent)
wav_rec = self.bicodec.detokenize(global_tokens, semantic_tokens) # (B, 1, T_wav)
# Remove channel dim and batch dim, convert to numpy
return wav_rec.detach().squeeze(0).squeeze(0).cpu().numpy()
# NOTE: a direct forward pass on SparkTTSModel is not the intended entry point for
# TTS; use the generate method or the pipeline instead. The effective forward,
# defined below, is a minimal GenerationMixin-compatible wrapper that delegates to
# the underlying LLM.
# Define prepare_inputs_for_generation if LLM needs specific handling.
def prepare_inputs_for_generation(self, input_ids, **kwargs):
""" Prepares inputs for the LLM's generate method. """
if not self.llm:
raise RuntimeError("LLM component not loaded.")
# --- START REVISED IMPLEMENTATION ---
# Delegate to the LLM's prepare_inputs_for_generation method directly.
# This ensures we use the exact logic defined for the specific LLM architecture (Qwen2).
# It should handle past_key_values, attention_mask, use_cache etc. correctly.
try:
# Pass all relevant kwargs received by the top-level generate call
# The LLM's method will select what it needs.
model_inputs = self.llm.prepare_inputs_for_generation(input_ids, **kwargs)
return model_inputs
except AttributeError:
# Fallback if the LLM doesn't have this method (unlikely for recent models)
logger.warning("LLM does not have 'prepare_inputs_for_generation'. Using basic fallback.")
model_kwargs = {}
model_kwargs["past_key_values"] = kwargs.get("past_key_values", None)
model_kwargs["use_cache"] = kwargs.get("use_cache", None)
# Ensure attention_mask is included if present in kwargs
if "attention_mask" in kwargs:
model_kwargs["attention_mask"] = kwargs["attention_mask"]
return {"input_ids": input_ids, **model_kwargs}
# --- END REVISED IMPLEMENTATION ---
# We need a minimal forward method compatible with GenerationMixin
# It should accept the output of prepare_inputs_for_generation
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
**kwargs # Accept other potential kwargs from prepare_inputs
) -> Any: # Return type depends on the LLM, usually CausalLMOutputWithPast
"""
Minimal forward pass that delegates to the underlying LLM.
Required for compatibility with GenerationMixin.
Accepts arguments typically returned by prepare_inputs_for_generation.
"""
if not self.llm:
raise RuntimeError("LLM component not loaded.")
# Filter arguments for the LLM's forward method
# (Some LLMs might not accept position_ids directly in forward when using past_key_values)
llm_kwargs = {
"past_key_values": past_key_values,
"attention_mask": attention_mask,
**kwargs # Pass through any other relevant kwargs
}
# Only pass position_ids if the LLM's forward signature accepts it
# This requires inspecting the LLM's forward signature or knowing its behavior.
# For simplicity, we might omit it if it causes issues, or handle it more dynamically.
# Let's assume the LLM forward can handle it for now if prepare_inputs included it.
if position_ids is not None:
llm_kwargs["position_ids"] = position_ids
return self.llm(input_ids=input_ids, **llm_kwargs)
# Add generate method to use GenerationMixin capabilities directly on SparkTTSModel if desired
# This will internally call prepare_inputs_for_generation and forward (which might need defining/adjusting)
# However, the pipeline calls self.model.llm.generate, so this might not be strictly needed unless you want `model.generate(...)`
# @torch.no_grad()
# def generate(self, *args, **kwargs):
# if not self.llm:
# raise RuntimeError("LLM component not loaded.")
# # This might need adjustments based on how GenerationMixin interacts with the overridden forward
# # return super().generate(*args, **kwargs) # Calls self.prepare_inputs + self.forward loop
# # Or directly call the LLM's generate if forward is problematic:
# return self.llm.generate(*args, **kwargs)
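# Comment-only end-to-end sketch using the internal audio helpers (the checkpoint
# path is an assumption; adjust to the actual model directory):
#
#   import soundfile as sf
#   model = SparkTTSModel.from_pretrained("path/to/Vi-SparkTTS-0.5B", trust_remote_code=True)
#   global_tokens, semantic_tokens = model._tokenize_audio("prompt.wav")
#   wav = model._detokenize_audio(global_tokens, semantic_tokens)
#   sf.write("reconstruction.wav", wav, model.config.sample_rate)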