|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" PyTorch SparkTTS model.""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
import os |
|
import warnings |
|
from pathlib import Path |
|
from typing import Dict, Any, Tuple, Optional, Union |
|
|
|
from transformers import PreTrainedModel, AutoModelForCausalLM, Wav2Vec2FeatureExtractor, Wav2Vec2Model |
|
from transformers.utils import logging, requires_backends |
|
from transformers.generation.utils import GenerationMixin |
|
from transformers.configuration_utils import PretrainedConfig |
|
from safetensors.torch import load_file |
|
import torchaudio.transforms as TT |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" Utility functions for SparkTTS """ |
|
|
|
import random |
|
import soxr |
|
import soundfile |
|
import torch |
|
import torchaudio |
|
import numpy as np |
|
|
|
from pathlib import Path |
|
from typing import Tuple, Dict, Any |
|
from numpy.lib.stride_tricks import sliding_window_view |
|
from omegaconf import OmegaConf |
|
|
|
|
|
|
|
TASK_TOKEN_MAP = { |
|
"vc": "<|task_vc|>", |
|
"tts": "<|task_tts|>", |
|
"asr": "<|task_asr|>", |
|
"s2s": "<|task_s2s|>", |
|
"t2s": "<|task_t2s|>", |
|
"understand": "<|task_understand|>", |
|
"caption": "<|task_cap|>", |
|
"controllable_tts": "<|task_controllable_tts|>", |
|
"prompt_tts": "<|task_prompt_tts|>", |
|
"speech_edit": "<|task_edit|>", |
|
} |
|
|
|
LEVELS_MAP = { |
|
"very_low": 0, |
|
"low": 1, |
|
"moderate": 2, |
|
"high": 3, |
|
"very_high": 4, |
|
} |
|
|
|
LEVELS_MAP_UI = { |
|
1: 'very_low', |
|
2: 'low', |
|
3: 'moderate', |
|
4: 'high', |
|
5: 'very_high' |
|
} |
|
|
|
GENDER_MAP = { |
|
"female": 0, |
|
"male": 1, |
|
} |
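
# Illustrative note (not executed): these maps translate user-facing controls into the
# model's prompt vocabulary. For example, TASK_TOKEN_MAP["controllable_tts"] yields
# "<|task_controllable_tts|>", GENDER_MAP["female"] yields 0, and a UI slider value of 4
# maps to LEVELS_MAP[LEVELS_MAP_UI[4]] == LEVELS_MAP["high"] == 3.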
|
|
|
|
|
def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray: |
|
temp = np.sort(np.abs(audio)) |
|
if len(temp) == 0: |
|
return audio |
|
if temp[-1] < 0.1: |
|
scaling_factor = max(temp[-1], 1e-3) |
|
audio = audio / scaling_factor * 0.1 |
|
temp = temp[temp > 0.01] |
|
L = temp.shape[0] |
|
if L <= 10: |
|
return audio |
|
volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)]) |
|
if volume == 0: |
|
return audio |
|
audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10) |
|
max_value = np.max(np.abs(audio)) if len(audio) > 0 else 0 |
|
if max_value > 1: |
|
audio = audio / max_value |
|
return audio |
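
# Behavior sketch (derived from the function above): very quiet signals are first rescaled
# to a peak of ~0.1, then the gain is set so the mean of the 90th-99th percentile
# magnitudes approaches `coeff`, clipped to [0.1, 10], with a final peak normalization if
# any sample exceeds 1.0. Illustrative call:
#
#     normalized = audio_volume_normalize(wav, coeff=0.2)  # wav: 1-D float numpy array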
|
|
|
def load_audio( |
|
adfile: Path, |
|
sampling_rate: int = None, |
|
length: int = None, |
|
volume_normalize: bool = False, |
|
segment_duration: int = None, |
|
) -> np.ndarray: |
|
try: |
|
audio, sr = soundfile.read(adfile, dtype='float32') |
|
except Exception as e: |
|
raise IOError(f"Could not read audio file {adfile}: {e}") |
|
|
|
if audio is None or len(audio) == 0: |
|
raise ValueError(f"Audio file {adfile} is empty or invalid.") |
|
|
|
if len(audio.shape) > 1: |
|
audio = audio[:, 0] |
|
|
|
if sampling_rate is not None and sr != sampling_rate: |
|
try: |
|
|
|
audio = audio.astype(np.float64) |
|
audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ") |
|
|
|
audio = audio.astype(np.float32) |
|
sr = sampling_rate |
|
except Exception as e: |
|
raise RuntimeError(f"Failed to resample audio from {sr}Hz to {sampling_rate}Hz: {e}") |
|
|
|
if segment_duration is not None: |
|
seg_length = int(sr * segment_duration) |
|
audio = random_select_audio_segment(audio, seg_length) |
|
|
|
if volume_normalize: |
|
audio = audio_volume_normalize(audio) |
|
|
|
if length is not None: |
|
if audio.shape[0] > length: |
|
audio = audio[:length] |
|
else: |
|
audio = np.pad(audio, (0, int(length - audio.shape[0])), mode='constant') |
|
return audio |
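
# Usage sketch (illustrative only, the file path is hypothetical):
#
#     wav = load_audio("prompt.wav", sampling_rate=16000, volume_normalize=True)
#     # -> 1-D float32 numpy array at 16 kHz; stereo inputs keep only channel 0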
|
|
|
def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray: |
|
if audio.shape[0] < length: |
|
audio = np.pad(audio, (0, int(length - audio.shape[0])), mode='constant') |
|
start_index = 0 |
|
elif audio.shape[0] == length: |
|
start_index = 0 |
|
else: |
|
start_index = random.randint(0, audio.shape[0] - length) |
|
|
|
end_index = int(start_index + length) |
|
return audio[start_index:end_index] |
|
|
|
|
|
def load_config_yaml(config_path: Path) -> Dict: |
|
"""Loads a YAML configuration file using OmegaConf.""" |
|
|
|
if not Path(config_path).is_file(): |
|
raise FileNotFoundError(f"YAML Config file not found: {config_path}") |
|
try: |
|
config = OmegaConf.load(config_path) |
|
|
|
return OmegaConf.to_container(config, resolve=True) |
|
except Exception as e: |
|
raise IOError(f"Error loading YAML config file {config_path}: {e}") |
|
|
|
|
|
""" PyTorch SparkTTS BiCodec sub-module definitions.""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.distributed as dist |
|
import random |
|
|
|
from torch.nn.utils import weight_norm, remove_weight_norm |
|
from torch import Tensor, einsum, int32
|
from torch.amp import autocast |
|
|
|
from typing import Any, Dict, List, Tuple, Optional |
|
from collections import namedtuple |
|
from functools import wraps, partial |
|
from contextlib import nullcontext |
|
from packaging import version |
|
|
|
from einops import rearrange, repeat, reduce, pack, unpack |
|
from einops.layers.torch import Rearrange |
|
from einx import get_at |
|
|
|
|
|
|
|
|
|
def WNConv1d(*args, **kwargs): |
|
return weight_norm(nn.Conv1d(*args, **kwargs)) |
|
|
|
|
|
def WNConvTranspose1d(*args, **kwargs): |
|
return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) |
|
|
|
|
|
|
|
@torch.jit.script |
|
def snake(x, alpha): |
|
shape = x.shape |
|
x = x.reshape(shape[0], shape[1], -1) |
|
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) |
|
x = x.reshape(shape) |
|
return x |
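
# The snake activation above computes x + (1 / alpha) * sin^2(alpha * x), with a 1e-9
# epsilon guarding the reciprocal; this periodic activation (as used in BigVGAN/DAC-style
# vocoders) helps the decoder model periodic audio structure.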
|
|
|
|
|
class Snake1d(nn.Module): |
|
def __init__(self, channels): |
|
super().__init__() |
|
self.alpha = nn.Parameter(torch.ones(1, channels, 1)) |
|
|
|
def forward(self, x): |
|
return snake(x, self.alpha) |
|
|
|
|
|
class ResidualUnit(nn.Module): |
|
def __init__(self, dim: int = 16, dilation: int = 1): |
|
super().__init__() |
|
pad = ((7 - 1) * dilation) // 2 |
|
self.block = nn.Sequential( |
|
Snake1d(dim), |
|
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), |
|
Snake1d(dim), |
|
WNConv1d(dim, dim, kernel_size=1), |
|
) |
|
|
|
def forward(self, x): |
|
y = self.block(x) |
|
|
|
diff = x.shape[-1] - y.shape[-1] |
|
if diff > 0: |
|
pad = diff // 2 |
|
x = x[..., pad:pad + y.shape[-1]] |
|
elif diff < 0: |
|
pad = -diff // 2 |
|
y = y[..., pad:pad + x.shape[-1]] |
|
|
|
return x + y |
|
|
|
|
|
def init_weights(m): |
|
if isinstance(m, nn.Conv1d): |
|
nn.init.trunc_normal_(m.weight, std=0.02) |
|
if m.bias is not None: |
|
nn.init.constant_(m.bias, 0) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SamplingBlock(nn.Module): |
|
"""Sampling block for upsampling or downsampling""" |
|
|
|
def __init__( |
|
self, |
|
dim: int, |
|
groups: int = 1, |
|
upsample_scale: int = 1, |
|
downsample_scale: int = 1, |
|
) -> None: |
|
""" |
|
Args: |
|
dim: input dimension |
|
groups: number of groups |
|
upsample_scale: upsampling scale |
|
downsample_scale: downsampling scale |
|
""" |
|
super(SamplingBlock, self).__init__() |
|
|
|
self.upsample_scale = upsample_scale |
|
self.downsample_scale = downsample_scale |
|
|
|
if self.upsample_scale > 1: |
|
self.de_conv_upsampler = nn.Sequential( |
|
nn.LeakyReLU(0.2), |
|
nn.ConvTranspose1d( |
|
dim, |
|
dim, |
|
kernel_size=upsample_scale * 2, |
|
stride=upsample_scale, |
|
padding=upsample_scale // 2 + upsample_scale % 2, |
|
output_padding=upsample_scale % 2, |
|
groups=groups, |
|
), |
|
) |
|
|
|
if self.downsample_scale > 1: |
|
self.conv_downsampler = nn.Sequential( |
|
nn.LeakyReLU(0.2), |
|
nn.Conv1d( |
|
dim, |
|
dim, |
|
kernel_size=2 * downsample_scale, |
|
stride=downsample_scale, |
|
padding=downsample_scale // 2 + downsample_scale % 2, |
|
groups=groups, |
|
), |
|
) |
|
|
|
@staticmethod |
|
def repeat_upsampler(x, upsample_scale): |
|
return x.repeat_interleave(upsample_scale, dim=2) |
|
|
|
@staticmethod |
|
def skip_downsampler(x, downsample_scale): |
|
return F.avg_pool1d(x, kernel_size=downsample_scale, stride=downsample_scale) |
|
|
|
def forward(self, x): |
|
|
|
|
|
if self.upsample_scale > 1: |
|
repeat_res = self.repeat_upsampler(x, self.upsample_scale) |
|
deconv_res = self.de_conv_upsampler(x) |
|
|
|
if deconv_res.shape[-1] > repeat_res.shape[-1]: |
|
deconv_res = deconv_res[..., :repeat_res.shape[-1]] |
|
elif repeat_res.shape[-1] > deconv_res.shape[-1]: |
|
repeat_res = repeat_res[..., :deconv_res.shape[-1]] |
|
upmerge_res = repeat_res + deconv_res |
|
else: |
|
upmerge_res = x |
|
repeat_res = x |
|
|
|
if self.downsample_scale > 1: |
|
conv_res = self.conv_downsampler(upmerge_res) |
|
skip2_res = self.skip_downsampler(upmerge_res, self.downsample_scale) |
|
skip1_res = self.skip_downsampler(repeat_res, self.downsample_scale) |
|
|
|
min_len = min(conv_res.shape[-1], skip1_res.shape[-1], skip2_res.shape[-1]) |
|
conv_res = conv_res[..., :min_len] |
|
skip1_res = skip1_res[..., :min_len] |
|
skip2_res = skip2_res[..., :min_len] |
|
else: |
|
conv_res = upmerge_res |
|
skip2_res = upmerge_res |
|
skip1_res = repeat_res |
|
|
|
final_res = conv_res + skip1_res + skip2_res |
|
|
|
|
|
|
|
return final_res |
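
# Shape sketch (illustrative): for x of shape (B, D, T),
#   SamplingBlock(D, groups=D, upsample_scale=2)(x)   -> roughly (B, D, 2 * T)
#   SamplingBlock(D, groups=D, downsample_scale=2)(x) -> roughly (B, D, T // 2)
# Upsampling sums a repeat-interleave path with a transposed-conv path; downsampling sums
# a strided-conv path with average-pooling skip paths, cropped to a common length.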
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TAP(nn.Module): |
|
""" |
|
Temporal average pooling, only first-order mean is considered |
|
""" |
|
|
|
def __init__(self, in_dim=0, **kwargs): |
|
super(TAP, self).__init__() |
|
self.in_dim = in_dim |
|
|
|
def forward(self, x): |
|
pooling_mean = x.mean(dim=-1) |
|
|
|
pooling_mean = pooling_mean.flatten(start_dim=1) |
|
return pooling_mean |
|
|
|
def get_out_dim(self): |
|
|
|
|
|
|
|
return self.in_dim |
|
|
|
|
|
class TSDP(nn.Module): |
|
""" |
|
Temporal standard deviation pooling, only second-order std is considered |
|
""" |
|
|
|
def __init__(self, in_dim=0, **kwargs): |
|
super(TSDP, self).__init__() |
|
self.in_dim = in_dim |
|
|
|
def forward(self, x): |
|
|
|
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7) |
|
pooling_std = pooling_std.flatten(start_dim=1) |
|
return pooling_std |
|
|
|
def get_out_dim(self): |
|
|
|
|
|
return self.in_dim |
|
|
|
|
|
class TSTP(nn.Module): |
|
""" |
|
Temporal statistics pooling, concatenate mean and std, which is used in |
|
x-vector |
|
    Comment: simple concatenation cannot make full use of both statistics
|
""" |
|
|
|
def __init__(self, in_dim=0, **kwargs): |
|
super(TSTP, self).__init__() |
|
self.in_dim = in_dim |
|
|
|
def forward(self, x): |
|
|
|
pooling_mean = x.mean(dim=-1) |
|
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7) |
|
pooling_mean = pooling_mean.flatten(start_dim=1) |
|
pooling_std = pooling_std.flatten(start_dim=1) |
|
stats = torch.cat((pooling_mean, pooling_std), 1) |
|
return stats |
|
|
|
def get_out_dim(self): |
|
|
|
|
|
return self.in_dim * 2 |
|
|
|
|
|
class ASTP(nn.Module): |
|
""" Attentive statistics pooling: Channel- and context-dependent |
|
statistics pooling, first used in ECAPA_TDNN. |
|
""" |
|
|
|
def __init__(self, |
|
in_dim, |
|
bottleneck_dim=128, |
|
global_context_att=False, |
|
**kwargs): |
|
super(ASTP, self).__init__() |
|
self.in_dim = in_dim |
|
self.global_context_att = global_context_att |
|
|
|
|
|
|
|
if global_context_att: |
|
self.linear1 = nn.Conv1d( |
|
in_dim * 3, bottleneck_dim, |
|
kernel_size=1) |
|
else: |
|
self.linear1 = nn.Conv1d( |
|
in_dim, bottleneck_dim, |
|
kernel_size=1) |
|
self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, |
|
kernel_size=1) |
|
|
|
def forward(self, x): |
|
""" |
|
x: a 3-dimensional tensor in tdnn-based architecture (B,F,T) |
|
or a 4-dimensional tensor in resnet architecture (B,C,F,T) |
|
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension) |
|
""" |
|
if len(x.shape) == 4: |
|
x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3]) |
|
assert len(x.shape) == 3 |
|
|
|
if self.global_context_att: |
|
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) |
|
context_std = torch.sqrt( |
|
torch.var(x, dim=-1, keepdim=True) + 1e-7).expand_as(x) |
|
x_in = torch.cat((x, context_mean, context_std), dim=1) |
|
else: |
|
x_in = x |
|
|
|
|
|
alpha = torch.tanh( |
|
self.linear1(x_in)) |
|
alpha = torch.softmax(self.linear2(alpha), dim=2) |
|
mean = torch.sum(alpha * x, dim=2) |
|
var = torch.sum(alpha * (x**2), dim=2) - mean**2 |
|
std = torch.sqrt(var.clamp(min=1e-7)) |
|
return torch.cat([mean, std], dim=1) |
|
|
|
def get_out_dim(self): |
|
|
|
|
|
return self.in_dim * 2 |
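
# Output-dimension summary for the statistics-pooling layers in this file
# (as returned by each get_out_dim):
#   TAP / TSDP                  -> in_dim       (mean-only or std-only)
#   TSTP / ASTP                 -> in_dim * 2   (mean and std concatenated)
#   MHASTP / MQMHASTP (below)   -> in_dim * 2, further multiplied by query_num for MQMHASTP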
|
|
|
|
|
class MHASTP(torch.nn.Module): |
|
""" Multi head attentive statistics pooling |
|
Reference: |
|
Self Multi-Head Attention for Speaker Recognition |
|
https://arxiv.org/pdf/1906.09890.pdf |
|
""" |
|
|
|
def __init__(self, |
|
in_dim, |
|
layer_num=2, |
|
head_num=2, |
|
d_s=1, |
|
bottleneck_dim=64, |
|
**kwargs): |
|
super(MHASTP, self).__init__() |
|
assert (in_dim % head_num |
|
) == 0 |
|
self.in_dim = in_dim |
|
self.head_num = head_num |
|
d_model = int(in_dim / head_num) |
|
channel_dims = [bottleneck_dim for i in range(layer_num + 1)] |
|
if d_s > 1: |
|
d_s = d_model |
|
else: |
|
d_s = 1 |
|
self.d_s = d_s |
|
channel_dims[0], channel_dims[-1] = d_model, d_s |
|
heads_att_trans = [] |
|
for i in range(self.head_num): |
|
att_trans = nn.Sequential() |
|
for j in range(layer_num - 1): |
|
att_trans.add_module( |
|
'att_' + str(j), |
|
nn.Conv1d(channel_dims[j], channel_dims[j + 1], 1, 1)) |
|
att_trans.add_module('tanh' + str(j), nn.Tanh()) |
|
att_trans.add_module( |
|
'att_' + str(layer_num - 1), |
|
nn.Conv1d(channel_dims[layer_num - 1], channel_dims[layer_num], |
|
1, 1)) |
|
heads_att_trans.append(att_trans) |
|
self.heads_att_trans = nn.ModuleList(heads_att_trans) |
|
|
|
def forward(self, input): |
|
""" |
|
input: a 3-dimensional tensor in xvector architecture |
|
or a 4-dimensional tensor in resnet architecture |
|
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension) |
|
""" |
|
if len(input.shape) == 4: |
|
input = input.reshape(input.shape[0], |
|
input.shape[1] * input.shape[2], |
|
input.shape[3]) |
|
assert len(input.shape) == 3 |
|
bs, f_dim, t_dim = input.shape |
|
chunks = torch.chunk(input, self.head_num, 1) |
|
|
|
chunks_out = [] |
|
for i, layer in enumerate(self.heads_att_trans): |
|
att_score = layer(chunks[i]) |
|
alpha = F.softmax(att_score, dim=-1) |
|
mean = torch.sum(alpha * chunks[i], dim=2) |
|
var = torch.sum(alpha * chunks[i]**2, dim=2) - mean**2 |
|
std = torch.sqrt(var.clamp(min=1e-7)) |
|
chunks_out.append(torch.cat((mean, std), dim=1)) |
|
out = torch.cat(chunks_out, dim=1) |
|
return out |
|
|
|
def get_out_dim(self): |
|
|
|
|
|
return self.in_dim * 2 |
|
|
|
|
|
class MQMHASTP(torch.nn.Module): |
|
""" An attentive pooling |
|
Reference: |
|
multi query multi head attentive statistics pooling |
|
https://arxiv.org/pdf/2110.05042.pdf |
|
Args: |
|
in_dim: the feature dimension of input |
|
layer_num: the number of layer in the pooling layer |
|
        query_num: the number of queries
|
head_num: the number of heads |
|
bottleneck_dim: the bottleneck dimension |
|
|
|
SA (H = 1, Q = 1, n = 2, d_s = 1) ref: |
|
https://www.danielpovey.com/files/2018_interspeech_xvector_attention.pdf |
|
MHA (H > 1, Q = 1, n = 1, d_s = 1) ref: |
|
https://arxiv.org/pdf/1906.09890.pdf |
|
AS (H = 1, Q > 1, n = 2, d_s = 1) ref: |
|
https://arxiv.org/pdf/1803.10963.pdf |
|
VSA (H = 1, Q > 1, n = 2, d_s = d_h) ref: |
|
http://www.interspeech2020.org/uploadfile/pdf/Mon-2-10-5.pdf |
|
""" |
|
|
|
def __init__(self, |
|
in_dim, |
|
layer_num=2, |
|
query_num=2, |
|
head_num=8, |
|
d_s=2, |
|
bottleneck_dim=64, |
|
**kwargs): |
|
super(MQMHASTP, self).__init__() |
|
self.n_query = nn.ModuleList([ |
|
MHASTP(in_dim, |
|
layer_num=layer_num, |
|
head_num=head_num, |
|
d_s=d_s, |
|
bottleneck_dim=bottleneck_dim) for i in range(query_num) |
|
]) |
|
self.query_num = query_num |
|
self.in_dim = in_dim |
|
|
|
def forward(self, input): |
|
""" |
|
input: a 3-dimensional tensor in xvector architecture |
|
or a 4-dimensional tensor in resnet architecture |
|
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension) |
|
""" |
|
if len(input.shape) == 4: |
|
input = input.reshape(input.shape[0], |
|
input.shape[1] * input.shape[2], |
|
input.shape[3]) |
|
assert len(input.shape) == 3 |
|
res = [] |
|
for i, layer in enumerate(self.n_query): |
|
res.append(layer(input)) |
|
out = torch.cat(res, dim=-1) |
|
return out |
|
|
|
def get_out_dim(self): |
|
|
|
|
|
return self.in_dim * 2 * self.query_num |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def exists(val): |
|
return val is not None |
|
|
|
def default(val, d): |
|
return val if exists(val) else d() if callable(d) else d |
|
|
|
class AdaLayerNorm(nn.Module): |
|
""" |
|
    Adaptive Layer Normalization module whose scale and shift are predicted from a conditioning embedding
|
|
|
Args: |
|
condition_dim (int): Dimension of the condition. |
|
embedding_dim (int): Dimension of the embeddings. |
|
""" |
|
|
|
def __init__(self, condition_dim: int, embedding_dim: int, eps: float = 1e-6): |
|
super().__init__() |
|
self.eps = eps |
|
self.dim = embedding_dim |
|
self.scale = nn.Linear(condition_dim, embedding_dim) |
|
self.shift = nn.Linear(condition_dim, embedding_dim) |
|
|
|
|
|
|
|
if self.scale.bias is not None: nn.init.zeros_(self.scale.bias) |
|
if self.shift.bias is not None: nn.init.zeros_(self.shift.bias) |
|
|
|
def forward(self, x: torch.Tensor, cond_embedding: torch.Tensor) -> torch.Tensor: |
|
scale = self.scale(cond_embedding) |
|
shift = self.shift(cond_embedding) |
|
x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) |
|
x = x * scale.unsqueeze(1) + shift.unsqueeze(1) |
|
return x |
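
# In effect: y = LayerNorm(x) * scale(cond) + shift(cond), where scale and shift are
# per-sample vectors predicted from the conditioning embedding and broadcast over the
# time dimension via unsqueeze(1).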
|
|
|
|
|
class ConvNeXtBlock(nn.Module): |
|
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. |
|
|
|
Args: |
|
dim (int): Number of input channels. |
|
intermediate_dim (int): Dimensionality of the intermediate layer. |
|
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. |
|
Defaults to None. |
|
condition_dim (int, optional): Dimension for AdaLayerNorm. |
|
None means non-conditional LayerNorm. Defaults to None. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim: int, |
|
intermediate_dim: int, |
|
layer_scale_init_value: float, |
|
condition_dim: Optional[int] = None, |
|
): |
|
super().__init__() |
|
self.dwconv = nn.Conv1d( |
|
dim, dim, kernel_size=7, padding=3, groups=dim |
|
) |
|
self.adanorm = condition_dim is not None |
|
if self.adanorm: |
|
self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6) |
|
else: |
|
self.norm = nn.LayerNorm(dim, eps=1e-6) |
|
self.pwconv1 = nn.Linear( |
|
dim, intermediate_dim |
|
) |
|
self.act = nn.GELU() |
|
self.pwconv2 = nn.Linear(intermediate_dim, dim) |
|
self.gamma = ( |
|
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) |
|
if layer_scale_init_value is not None and layer_scale_init_value > 0 |
|
else None |
|
) |
|
|
|
def forward( |
|
self, x: torch.Tensor, cond_embedding: Optional[torch.Tensor] = None |
|
) -> torch.Tensor: |
|
residual = x |
|
x = self.dwconv(x) |
|
x = x.transpose(1, 2) |
|
if self.adanorm: |
|
assert cond_embedding is not None, "Conditioning embedding required for AdaLayerNorm" |
|
x = self.norm(x, cond_embedding) |
|
else: |
|
x = self.norm(x) |
|
x = self.pwconv1(x) |
|
x = self.act(x) |
|
x = self.pwconv2(x) |
|
if self.gamma is not None: |
|
x = self.gamma * x |
|
x = x.transpose(1, 2) |
|
|
|
x = residual + x |
|
return x |
|
|
|
|
|
class ResBlock1(nn.Module): |
|
""" |
|
ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions, |
|
but without upsampling layers. |
|
|
|
Args: |
|
dim (int): Number of input channels. |
|
kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3. |
|
dilation (tuple[int], optional): Dilation factors for the dilated convolutions. |
|
Defaults to (1, 3, 5). |
|
lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function. |
|
Defaults to 0.1. |
|
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. |
|
Defaults to None. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim: int, |
|
kernel_size: int = 3, |
|
dilation: Tuple[int, int, int] = (1, 3, 5), |
|
lrelu_slope: float = 0.1, |
|
layer_scale_init_value: Optional[float] = None, |
|
): |
|
super().__init__() |
|
self.lrelu_slope = lrelu_slope |
|
self.convs1 = nn.ModuleList( |
|
[ |
|
weight_norm( |
|
nn.Conv1d( |
|
dim, |
|
dim, |
|
kernel_size, |
|
1, |
|
dilation=dilation[0], |
|
padding=self.get_padding(kernel_size, dilation[0]), |
|
) |
|
), |
|
weight_norm( |
|
nn.Conv1d( |
|
dim, |
|
dim, |
|
kernel_size, |
|
1, |
|
dilation=dilation[1], |
|
padding=self.get_padding(kernel_size, dilation[1]), |
|
) |
|
), |
|
weight_norm( |
|
nn.Conv1d( |
|
dim, |
|
dim, |
|
kernel_size, |
|
1, |
|
dilation=dilation[2], |
|
padding=self.get_padding(kernel_size, dilation[2]), |
|
) |
|
), |
|
] |
|
) |
|
|
|
self.convs2 = nn.ModuleList( |
|
[ |
|
weight_norm( |
|
nn.Conv1d( |
|
dim, |
|
dim, |
|
kernel_size, |
|
1, |
|
dilation=1, |
|
padding=self.get_padding(kernel_size, 1), |
|
) |
|
), |
|
weight_norm( |
|
nn.Conv1d( |
|
dim, |
|
dim, |
|
kernel_size, |
|
1, |
|
dilation=1, |
|
padding=self.get_padding(kernel_size, 1), |
|
) |
|
), |
|
weight_norm( |
|
nn.Conv1d( |
|
dim, |
|
dim, |
|
kernel_size, |
|
1, |
|
dilation=1, |
|
padding=self.get_padding(kernel_size, 1), |
|
) |
|
), |
|
] |
|
) |
|
|
|
self.gamma = nn.ParameterList( |
|
[ |
|
( |
|
nn.Parameter( |
|
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True |
|
) |
|
if layer_scale_init_value is not None |
|
else None |
|
), |
|
( |
|
nn.Parameter( |
|
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True |
|
) |
|
if layer_scale_init_value is not None |
|
else None |
|
), |
|
( |
|
nn.Parameter( |
|
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True |
|
) |
|
if layer_scale_init_value is not None |
|
else None |
|
), |
|
] |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma): |
|
xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope) |
|
xt = c1(xt) |
|
xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope) |
|
xt = c2(xt) |
|
if gamma is not None: |
|
xt = gamma * xt |
|
x = xt + x |
|
return x |
|
|
|
def remove_weight_norm(self): |
|
for l in self.convs1: |
|
remove_weight_norm(l) |
|
for l in self.convs2: |
|
remove_weight_norm(l) |
|
|
|
@staticmethod |
|
def get_padding(kernel_size: int, dilation: int = 1) -> int: |
|
return int((kernel_size * dilation - dilation) / 2) |
|
|
|
|
|
class Backbone(nn.Module): |
|
"""Base class for the generator's backbone. It preserves the same temporal resolution across all layers.""" |
|
|
|
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: |
|
""" |
|
Args: |
|
x (Tensor): Input tensor of shape (B, C, L), where B is the batch size, |
|
C denotes input features, and L is the sequence length. |
|
|
|
Returns: |
|
Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length, |
|
and H denotes the model dimension. |
|
""" |
|
raise NotImplementedError("Subclasses must implement the forward method.") |
|
|
|
|
|
class VocosBackbone(Backbone): |
|
""" |
|
Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization |
|
|
|
Args: |
|
input_channels (int): Number of input features channels. |
|
dim (int): Hidden dimension of the model. |
|
intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock. |
|
num_layers (int): Number of ConvNeXtBlock layers. |
|
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`. |
|
condition_dim (int, optional): Dimension for AdaLayerNorm. |
|
None means non-conditional model. Defaults to None. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
input_channels: int, |
|
dim: int, |
|
intermediate_dim: int, |
|
num_layers: int, |
|
layer_scale_init_value: Optional[float] = None, |
|
condition_dim: Optional[int] = None, |
|
): |
|
super().__init__() |
|
self.input_channels = input_channels |
|
self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3) |
|
self.adanorm = condition_dim is not None |
|
if self.adanorm: |
|
self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6) |
|
else: |
|
self.norm = nn.LayerNorm(dim, eps=1e-6) |
|
        layer_scale_init_value = (layer_scale_init_value or 1 / num_layers) if num_layers > 0 else None
|
self.convnext = nn.ModuleList( |
|
[ |
|
ConvNeXtBlock( |
|
dim=dim, |
|
intermediate_dim=intermediate_dim, |
|
layer_scale_init_value=layer_scale_init_value, |
|
condition_dim=condition_dim, |
|
) |
|
for _ in range(num_layers) |
|
] |
|
) |
|
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6) |
|
self.apply(self._init_weights) |
|
|
|
def _init_weights(self, m): |
|
if isinstance(m, (nn.Conv1d, nn.Linear)): |
|
nn.init.trunc_normal_(m.weight, std=0.02) |
|
if m.bias is not None: |
|
nn.init.constant_(m.bias, 0) |
|
|
|
def forward(self, x: torch.Tensor, condition: Optional[torch.Tensor] = None) -> torch.Tensor: |
|
|
|
x = self.embed(x) |
|
|
|
x_transposed = x.transpose(1, 2) |
|
if self.adanorm: |
|
assert condition is not None |
|
norm_out = self.norm(x_transposed, condition) |
|
else: |
|
norm_out = self.norm(x_transposed) |
|
|
|
x = norm_out.transpose(1, 2) |
|
for conv_block in self.convnext: |
|
x = conv_block(x, condition) |
|
|
|
x = self.final_layer_norm(x.transpose(1, 2)) |
|
return x |
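
# Shape contract (from the Backbone base class): input x is (B, input_channels, L) and the
# output is (B, L, dim) after the final LayerNorm; `condition` of shape (B, condition_dim)
# is required whenever condition_dim was given (AdaLayerNorm path).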
|
|
|
|
|
class VocosResNetBackbone(Backbone): |
|
""" |
|
Vocos backbone module built with ResBlocks. |
|
|
|
Args: |
|
input_channels (int): Number of input features channels. |
|
dim (int): Hidden dimension of the model. |
|
num_blocks (int): Number of ResBlock1 blocks. |
|
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
input_channels, |
|
dim, |
|
num_blocks, |
|
layer_scale_init_value=None, |
|
): |
|
super().__init__() |
|
self.input_channels = input_channels |
|
self.embed = weight_norm( |
|
nn.Conv1d(input_channels, dim, kernel_size=3, padding=1) |
|
) |
|
        layer_scale_init_value = (layer_scale_init_value or 1 / num_blocks / 3) if num_blocks > 0 else None
|
self.resnet = nn.Sequential( |
|
*[ |
|
ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) |
|
for _ in range(num_blocks) |
|
] |
|
) |
|
|
|
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: |
|
|
|
x = self.embed(x) |
|
|
|
x = self.resnet(x) |
|
|
|
x = x.transpose(1, 2) |
|
return x |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Decoder(nn.Module): |
|
"""Decoder module with convnext and upsampling blocks |
|
|
|
Args: |
|
sample_ratios (List[int]): sample ratios |
|
example: [2, 2] means upsample by 2x and then upsample by 2x |
|
""" |
|
|
|
def __init__( |
|
self, |
|
input_channels: int, |
|
vocos_dim: int, |
|
vocos_intermediate_dim: int, |
|
vocos_num_layers: int, |
|
out_channels: int, |
|
condition_dim: int = None, |
|
sample_ratios: List[int] = [1, 1], |
|
use_tanh_at_final: bool = False, |
|
): |
|
super().__init__() |
|
|
|
self.linear_pre = nn.Linear(input_channels, vocos_dim) |
|
|
|
upsample_modules = [] |
|
current_dim = vocos_dim |
|
for i, ratio in enumerate(sample_ratios): |
|
upsample_modules.append( |
|
nn.Sequential( |
|
SamplingBlock( |
|
dim=current_dim, |
|
groups=current_dim, |
|
upsample_scale=ratio, |
|
), |
|
|
|
|
|
|
|
|
|
|
|
nn.Conv1d(current_dim, current_dim, kernel_size=3, padding=1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
) |
|
) |
|
|
|
|
|
|
|
self.upsample = nn.Sequential(*upsample_modules) |
|
|
|
|
|
self.vocos_backbone = VocosBackbone( |
|
input_channels=current_dim, |
|
dim=vocos_dim, |
|
intermediate_dim=vocos_intermediate_dim, |
|
num_layers=vocos_num_layers, |
|
condition_dim=condition_dim, |
|
) |
|
self.linear_post = nn.Linear(vocos_dim, out_channels) |
|
self.use_tanh_at_final = use_tanh_at_final |
|
|
|
def forward(self, x: torch.Tensor, c: torch.Tensor = None): |
|
"""decoder forward. |
|
|
|
Args: |
|
x (torch.Tensor): (batch_size, input_channels, length) |
|
c (torch.Tensor): (batch_size, condition_dim) - Optional condition |
|
|
|
Returns: |
|
x (torch.Tensor): (batch_size, out_channels, length_upsampled) |
|
""" |
|
|
|
x = self.linear_pre(x.transpose(1, 2)) |
|
x = x.transpose(1, 2) |
|
|
|
|
|
x = self.upsample(x) |
|
|
|
|
|
x = self.vocos_backbone(x, condition=c) |
|
|
|
x = self.linear_post(x) |
|
x = x.transpose(1, 2) |
|
|
|
if self.use_tanh_at_final: |
|
x = torch.tanh(x) |
|
|
|
return x |
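
# Usage sketch (illustrative dimensions, not the shipped configuration): with
# sample_ratios=[2, 2] the temporal length grows by a factor of 4 before the Vocos backbone:
#
#     dec = Decoder(input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
#                   vocos_num_layers=12, out_channels=1024, sample_ratios=[2, 2])
#     y = dec(torch.randn(1, 1024, 50))   # -> (1, 1024, 200)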
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Encoder(nn.Module): |
|
"""Encoder module with convnext and downsampling blocks""" |
|
|
|
def __init__( |
|
self, |
|
input_channels: int, |
|
vocos_dim: int, |
|
vocos_intermediate_dim: int, |
|
vocos_num_layers: int, |
|
out_channels: int, |
|
sample_ratios: List[int] = [1, 1], |
|
): |
|
super().__init__() |
|
""" |
|
Encoder module with VocosBackbone and sampling blocks. |
|
|
|
Args: |
|
sample_ratios (List[int]): sample ratios |
|
example: [2, 2] means downsample by 2x and then downsample by 2x |
|
""" |
|
|
|
self.encoder_backbone = VocosBackbone( |
|
input_channels=input_channels, |
|
dim=vocos_dim, |
|
intermediate_dim=vocos_intermediate_dim, |
|
num_layers=vocos_num_layers, |
|
condition_dim=None, |
|
) |
|
|
|
downsample_modules = [] |
|
current_dim = vocos_dim |
|
for i, ratio in enumerate(sample_ratios): |
|
downsample_modules.append( |
|
nn.Sequential( |
|
SamplingBlock( |
|
dim=current_dim, |
|
groups=current_dim, |
|
downsample_scale=ratio, |
|
), |
|
|
|
nn.Conv1d(current_dim, current_dim, kernel_size=3, padding=1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
) |
|
) |
|
|
|
|
|
self.downsample = nn.Sequential(*downsample_modules) |
|
|
|
self.project = nn.Linear(current_dim, out_channels) |
|
|
|
def forward(self, x: torch.Tensor, *args): |
|
""" |
|
Args: |
|
x (torch.Tensor): (batch_size, input_channels, length) |
|
|
|
Returns: |
|
x (torch.Tensor): (batch_size, out_channels, length_downsampled) |
|
""" |
|
|
|
x = self.encoder_backbone(x) |
|
x = x.transpose(1, 2) |
|
|
|
|
|
x = self.downsample(x) |
|
|
|
x = x.transpose(1, 2) |
|
x = self.project(x) |
|
return x.transpose(1, 2) |
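
# Usage sketch (illustrative dimensions): the mirror of the Decoder above; with
# sample_ratios=[2, 2] the temporal length shrinks by a factor of 4, e.g. an input of
# shape (B, input_channels, 200) becomes (B, out_channels, 50).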
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DecoderBlock(nn.Module): |
|
def __init__( |
|
self, |
|
input_dim: int = 16, |
|
output_dim: int = 8, |
|
kernel_size: int = 2, |
|
stride: int = 1, |
|
): |
|
super().__init__() |
|
|
|
stride = max(1, stride) |
|
|
|
if kernel_size < stride: |
|
kernel_size = stride |
|
|
|
padding = (kernel_size - stride) // 2 |
|
        # output_padding must stay strictly smaller than the stride for ConvTranspose1d,
        # so only apply it when the block actually upsamples (stride > 1).
        output_padding = stride % 2 if (kernel_size % 2 == 0 and stride > 1) else 0
|
|
|
|
|
|
|
|
|
self.block = nn.Sequential( |
|
Snake1d(input_dim), |
|
WNConvTranspose1d( |
|
input_dim, |
|
output_dim, |
|
kernel_size=kernel_size, |
|
stride=stride, |
|
padding=padding, |
|
output_padding=output_padding, |
|
), |
|
ResidualUnit(output_dim, dilation=1), |
|
ResidualUnit(output_dim, dilation=3), |
|
ResidualUnit(output_dim, dilation=9), |
|
) |
|
|
|
def forward(self, x): |
|
return self.block(x) |
|
|
|
|
|
class WaveGenerator(nn.Module): |
|
def __init__( |
|
self, |
|
input_channel, |
|
channels, |
|
rates, |
|
kernel_sizes, |
|
d_out: int = 1, |
|
): |
|
super().__init__() |
|
|
|
|
|
layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] |
|
|
|
|
|
current_channels = channels |
|
for i, (kernel_size, stride) in enumerate(zip(kernel_sizes, rates)): |
|
input_dim = current_channels |
|
|
|
output_dim = max(1, channels // (2 ** (i + 1))) |
|
layers += [DecoderBlock(input_dim, output_dim, kernel_size, stride)] |
|
current_channels = output_dim |
|
|
|
|
|
layers += [ |
|
Snake1d(current_channels), |
|
WNConv1d(current_channels, d_out, kernel_size=7, padding=3), |
|
nn.Tanh(), |
|
] |
|
|
|
self.model = nn.Sequential(*layers) |
|
|
|
self.apply(init_weights) |
|
|
|
def forward(self, x): |
|
return self.model(x) |
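
# Note (derived from the code): each DecoderBlock upsamples time by its `stride`, so the
# overall upsampling factor of WaveGenerator is prod(rates), while the channel width is
# halved at every block starting from `channels`. Illustrative (not the shipped config):
# rates=[8, 5, 4, 2] would upsample a token-rate feature map by 320x to waveform samples.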
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def round_ste(z: Tensor) -> Tensor: |
|
"""Round with straight through gradients.""" |
|
zhat = z.round() |
|
return z + (zhat - z).detach() |
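
# Straight-through estimator: the forward value equals round(z), but because the rounded
# difference is wrapped in .detach(), the backward pass sees d(out)/dz = 1, letting
# gradients flow through the non-differentiable rounding step.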
|
|
|
|
|
class FSQ(nn.Module): |
|
def __init__( |
|
self, |
|
levels: List[int], |
|
        dim: Optional[int] = None,

        num_codebooks=1,

        keep_num_codebooks_dim: Optional[bool] = None,

        scale: Optional[float] = None,
|
allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64), |
|
channel_first: bool = False, |
|
projection_has_bias: bool = True, |
|
return_indices=True, |
|
force_quantization_f32=True, |
|
): |
|
super().__init__() |
|
_levels = torch.tensor(levels, dtype=int32) |
|
self.register_buffer("_levels", _levels, persistent=False) |
|
|
|
_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32) |
|
self.register_buffer("_basis", _basis, persistent=False) |
|
|
|
self.scale = scale |
|
|
|
codebook_dim = len(levels) |
|
self.codebook_dim = codebook_dim |
|
|
|
effective_codebook_dim = codebook_dim * num_codebooks |
|
self.num_codebooks = num_codebooks |
|
self.effective_codebook_dim = effective_codebook_dim |
|
|
|
|
|
|
|
if num_codebooks == 1: |
|
keep_num_codebooks_dim = False |
|
else: |
|
keep_num_codebooks_dim = default(keep_num_codebooks_dim, True) |
|
|
|
|
|
|
|
if num_codebooks > 1 and not keep_num_codebooks_dim: |
|
raise ValueError("If num_codebooks > 1, keep_num_codebooks_dim must be True or None (defaults to True).") |
|
self.keep_num_codebooks_dim = keep_num_codebooks_dim |
|
|
|
|
|
self.dim = default(dim, len(_levels) * num_codebooks) |
|
|
|
self.channel_first = channel_first |
|
|
|
has_projections = self.dim != effective_codebook_dim |
|
self.project_in = ( |
|
nn.Linear(self.dim, effective_codebook_dim, bias=projection_has_bias) |
|
if has_projections |
|
else nn.Identity() |
|
) |
|
self.project_out = ( |
|
nn.Linear(effective_codebook_dim, self.dim, bias=projection_has_bias) |
|
if has_projections |
|
else nn.Identity() |
|
) |
|
|
|
self.has_projections = has_projections |
|
|
|
self.return_indices = return_indices |
|
if return_indices: |
|
self.codebook_size = self._levels.prod().item() |
|
|
|
|
|
|
|
|
|
|
|
self.allowed_dtypes = allowed_dtypes |
|
self.force_quantization_f32 = force_quantization_f32 |
|
|
|
@property |
|
def implicit_codebook(self): |
|
|
|
device = self._levels.device |
|
indices = torch.arange(self.codebook_size, device=device) |
|
return self._indices_to_codes(indices) |
|
|
|
|
|
def bound(self, z, eps: float = 1e-3): |
|
"""Bound `z`, an array of shape (..., d).""" |
|
levels = self._levels.to(z.device) |
|
half_l = (levels - 1) * (1 + eps) / 2 |
|
offset = torch.where(levels % 2 == 0, 0.5, 0.0) |
|
shift = (offset / half_l).atanh() if torch.any(half_l != 0) else torch.zeros_like(offset) |
|
|
|
shift = shift.view(1, 1, -1) if z.ndim == 3 else shift |
|
half_l = half_l.view(1, 1, -1) if z.ndim == 3 else half_l |
|
|
|
|
|
z_clipped = torch.clamp(z, min=-1.0 + eps, max=1.0 - eps) |
|
|
|
|
|
|
|
|
|
|
|
upper_bound = (levels - 1) / 2 |
|
lower_bound = -upper_bound |
|
upper_bound = upper_bound.view(1, 1, -1) if z.ndim == 3 else upper_bound |
|
lower_bound = lower_bound.view(1, 1, -1) if z.ndim == 3 else lower_bound |
|
|
|
return torch.clamp(z, min=lower_bound, max=upper_bound) |
|
|
|
|
|
def quantize(self, z): |
|
"""Quantizes z, returns quantized zhat, same shape as z.""" |
|
quantized = round_ste(self.bound(z)) |
|
levels = self._levels.to(z.device) |
|
half_width = levels // 2 |
|
|
|
half_width = torch.where(half_width == 0, torch.tensor(1.0, device=z.device), half_width.float()) |
|
half_width_view = half_width.view(1, 1, -1) if quantized.ndim == 3 else half_width |
|
return quantized / half_width_view |
|
|
|
def _scale_and_shift(self, zhat_normalized): |
|
levels = self._levels.to(zhat_normalized.device) |
|
half_width = levels // 2 |
|
half_width_view = half_width.view(1, 1, -1) if zhat_normalized.ndim == 3 else half_width |
|
return (zhat_normalized * half_width_view) + half_width_view |
|
|
|
def _scale_and_shift_inverse(self, zhat): |
|
levels = self._levels.to(zhat.device) |
|
half_width = levels // 2 |
|
|
|
half_width = torch.where(half_width == 0, torch.tensor(1.0, device=zhat.device), half_width.float()) |
|
half_width_view = half_width.view(1, 1, -1) if zhat.ndim == 3 else half_width |
|
return (zhat - half_width_view) / half_width_view |
|
|
|
def _indices_to_codes(self, indices): |
|
level_indices = self.indices_to_level_indices(indices) |
|
codes = self._scale_and_shift_inverse(level_indices.float()) |
|
return codes |
|
|
|
def codes_to_indices(self, zhat): |
|
"""Converts a `code` to an index in the codebook.""" |
|
assert zhat.shape[-1] == self.codebook_dim |
|
zhat_scaled = self._scale_and_shift(zhat) |
|
|
|
basis = self._basis.to(zhat.device, dtype=int32) |
|
basis_view = basis.view(1, 1, -1) if zhat_scaled.ndim == 3 else basis |
|
|
|
product = (zhat_scaled * basis_view).round().int() |
|
return product.sum(dim=-1).to(int32) |
|
|
|
def indices_to_level_indices(self, indices): |
|
"""Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings""" |
|
indices_reshaped = rearrange(indices, "... -> ... 1") |
|
basis = self._basis.to(indices.device) |
|
levels = self._levels.to(indices.device) |
|
|
|
basis_view = basis.view(*([1] * (indices_reshaped.ndim - 1)), -1) |
|
levels_view = levels.view(*([1] * (indices_reshaped.ndim - 1)), -1) |
|
|
|
codes_non_centered = (indices_reshaped // basis_view) % levels_view |
|
return codes_non_centered |
|
|
|
|
|
|
|
def forward(self, z): |
|
""" |
|
einstein notation |
|
b - batch |
|
... - sequence, spatial dimensions |
|
d - feature dimension |
|
c - number of codebook dim (within a single quantizer) |
|
g - number of quantizers (groups) - handled by ResidualFSQ/GroupedResidualFSQ |
|
""" |
|
|
|
|
|
|
|
|
|
if self.channel_first: |
|
|
|
if z.ndim > 2: |
|
z = rearrange(z, "b d ... -> b ... d") |
|
z, ps = pack([z], "b * d") |
|
|
|
else: |
|
|
|
if z.ndim > 2: |
|
z, ps = pack([z], "b * d") |
|
|
|
|
|
|
|
assert ( |
|
z.shape[-1] == self.dim |
|
), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}" |
|
|
|
|
|
z_projected = self.project_in(z) |
|
|
|
|
|
if self.num_codebooks > 1: |
|
z_reshaped = rearrange(z_projected, "b ... (c d) -> b ... c d", c=self.num_codebooks) |
|
else: |
|
|
|
z_reshaped = rearrange(z_projected, "b ... d -> b ... 1 d") |
|
|
|
|
|
force_f32 = self.force_quantization_f32 |
|
quantization_context = ( |
|
partial(autocast, "cuda", enabled=False) if force_f32 else nullcontext |
|
) |
|
|
|
codes = None |
|
indices = None |
|
|
|
with quantization_context(): |
|
orig_dtype = z_reshaped.dtype |
|
|
|
if force_f32 and orig_dtype not in self.allowed_dtypes: |
|
z_for_quant = z_reshaped.float() |
|
else: |
|
z_for_quant = z_reshaped |
|
|
|
codes = self.quantize(z_for_quant) |
|
|
|
if self.return_indices: |
|
indices = self.codes_to_indices(codes) |
|
|
|
|
|
codes = codes.type(orig_dtype) |
|
|
|
|
|
|
|
if self.num_codebooks > 1: |
|
codes_reshaped = rearrange(codes, "b ... c d -> b ... (c d)") |
|
else: |
|
codes_reshaped = rearrange(codes, "b ... 1 d -> b ... d") |
|
|
|
out = self.project_out(codes_reshaped) |
|
|
|
|
|
if z.ndim > 2: |
|
out = unpack(out, ps, "b * d")[0] |
|
if self.return_indices: |
|
indices = unpack(indices, ps, "b * c")[0] |
|
|
|
|
|
if self.channel_first and out.ndim > 2: |
|
out = rearrange(out, "b ... d -> b d ...") |
|
if self.return_indices and indices.ndim > 1: |
|
|
|
|
|
|
|
|
|
pass |
|
|
|
|
|
|
|
if self.return_indices and self.num_codebooks == 1 and not self.keep_num_codebooks_dim: |
|
indices = indices.squeeze(-1) |
|
|
|
|
|
return out, indices |
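
# Usage sketch (illustrative levels, not the shipped configuration):
#
#     fsq = FSQ(levels=[8, 5, 5, 5])            # codebook_size = 8 * 5 * 5 * 5 = 1000
#     out, idx = fsq(torch.randn(2, 37, 4))     # out: (2, 37, 4), idx: (2, 37)
#
# Each channel is clamped and rounded to one of `levels[i]` values; the tuple of per-channel
# levels defines the implicit codebook, so there are no learned codebook embeddings.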
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_distributed(): |
|
return dist.is_initialized() and dist.get_world_size() > 1 |
|
|
|
def get_maybe_sync_seed(device, max_size=10_000): |
|
rand_int = torch.randint(0, max_size, (), device=device) |
|
if is_distributed(): |
|
|
|
if rand_int.device != device: |
|
rand_int = rand_int.to(device) |
|
dist.all_reduce(rand_int) |
|
return rand_int.item() |
|
|
|
def round_up_multiple(num, mult): |
|
|
|
if mult <= 0: |
|
return num |
|
|
|
return (num + mult - 1) // mult * mult |
|
|
|
|
|
class ResidualFSQ(nn.Module): |
|
"""Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf""" |
|
|
|
def __init__( |
|
self, |
|
*, |
|
levels: List[int], |
|
num_quantizers, |
|
dim=None, |
|
|
|
quantize_dropout=False, |
|
quantize_dropout_cutoff_index=0, |
|
quantize_dropout_multiple_of=1, |
|
channel_first: bool = False, |
|
**kwargs, |
|
): |
|
super().__init__() |
|
codebook_dim = len(levels) |
|
dim = default(dim, codebook_dim) |
|
|
|
requires_projection = codebook_dim != dim |
|
self.project_in = ( |
|
nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity() |
|
) |
|
self.project_out = ( |
|
nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity() |
|
) |
|
self.has_projections = requires_projection |
|
|
|
self.channel_first = channel_first |
|
self.num_quantizers = num_quantizers |
|
|
|
self.levels = levels |
|
self.layers = nn.ModuleList([]) |
|
|
|
levels_tensor = torch.Tensor(levels) |
|
|
|
scales = [] |
|
|
|
for ind in range(num_quantizers): |
|
|
|
|
|
|
|
|
|
|
|
scale_value = 1.0 |
|
scales.append(scale_value) |
|
|
|
|
|
fsq = FSQ(levels=levels, dim=codebook_dim, channel_first=channel_first, **kwargs) |
|
|
|
self.layers.append(fsq) |
|
|
|
|
|
assert all([not fsq.has_projections for fsq in self.layers]), "FSQ layers within ResidualFSQ should not have internal projections." |
|
|
|
self.codebook_size = self.layers[0].codebook_size |
|
|
|
|
|
|
|
|
|
|
|
|
|
self.quantize_dropout = quantize_dropout and num_quantizers > 1 |
|
|
|
assert quantize_dropout_cutoff_index >= 0 |
|
|
|
self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index |
|
self.quantize_dropout_multiple_of = quantize_dropout_multiple_of |
|
|
|
@property |
|
def codebooks(self): |
|
|
|
codebooks = [layer.implicit_codebook for layer in self.layers] |
|
codebooks = torch.stack(codebooks, dim=0) |
|
return codebooks |
|
|
|
def get_codes_from_indices(self, indices): |
|
|
|
num_dims = indices.ndim |
|
q_dim = -1 |
|
|
|
|
|
for i in range(num_dims): |
|
if indices.shape[i] == self.num_quantizers: |
|
q_dim = i |
|
break |
|
if q_dim == -1 and self.num_quantizers == 1 and indices.shape[-1] != 1: |
|
|
|
indices = indices.unsqueeze(-1) |
|
q_dim = -1 |
|
elif q_dim == -1: |
|
raise ValueError(f"Could not find quantizer dimension ({self.num_quantizers}) in indices shape {indices.shape}") |
|
|
|
|
|
if q_dim != num_dims - 1: |
|
permute_dims = list(range(num_dims)) |
|
permute_dims.pop(q_dim) |
|
permute_dims.append(q_dim) |
|
indices = indices.permute(*permute_dims) |
|
|
|
|
|
batch_shape = indices.shape[:-1] |
|
indices = indices.reshape(-1, self.num_quantizers) |
|
|
|
|
|
if indices.max() >= self.codebook_size: |
|
raise ValueError(f"Invalid index found in indices: {indices.max()}. Max allowed is {self.codebook_size - 1}.") |
|
if indices.min() < -1: |
|
raise ValueError(f"Invalid index found in indices: {indices.min()}. Min allowed is -1 (dropout).") |
|
|
|
mask = indices == -1 |
|
effective_indices = indices.masked_fill(mask, 0) |
|
|
|
all_codes = [] |
|
|
|
for i in range(self.num_quantizers): |
|
layer_indices = effective_indices[:, i] |
|
|
|
|
|
|
|
layer_codes = self.layers[i].indices_to_codes(layer_indices) |
|
all_codes.append(layer_codes) |
|
|
|
all_codes_tensor = torch.stack(all_codes, dim=0) |
|
|
|
|
|
mask_expanded = mask.permute(1, 0).unsqueeze(-1) |
|
all_codes_tensor = all_codes_tensor.masked_fill(mask_expanded, 0.0) |
|
|
|
|
|
all_codes_tensor = all_codes_tensor.reshape(self.num_quantizers, *batch_shape, -1) |
|
|
|
|
|
if q_dim != num_dims - 1: |
|
|
|
inv_permute_dims = list(range(num_dims)) |
|
inv_permute_dims.insert(q_dim, num_dims) |
|
inv_permute_dims.pop() |
|
|
|
|
|
|
|
|
|
|
|
pass |
|
|
|
return all_codes_tensor |
|
|
|
|
|
def get_output_from_indices(self, indices): |
|
|
|
codes = self.get_codes_from_indices(indices) |
|
codes_summed = reduce(codes, "q b ... d -> b ... d", "sum") |
|
|
|
output = self.project_out(codes_summed) |
|
|
|
|
|
if self.channel_first and output.ndim > 2: |
|
|
|
output = rearrange(output, "b ... d -> b d ...") |
|
|
|
return output |
|
|
|
def forward(self, x, return_all_codes=False, rand_quantize_dropout_fixed_seed=None): |
|
num_quant, quant_dropout_multiple_of, device = ( |
|
self.num_quantizers, |
|
self.quantize_dropout_multiple_of, |
|
x.device, |
|
) |
|
|
|
|
|
original_shape = x.shape |
|
if self.channel_first: |
|
if x.ndim > 2: |
|
x = rearrange(x, "b d ... -> b ... d") |
|
x, ps = pack([x], "b * d") |
|
|
|
else: |
|
|
|
if x.ndim > 2: |
|
x, ps = pack([x], "b * d") |
|
|
|
|
|
|
|
|
|
projected_x = self.project_in(x) |
|
|
|
quantized_out = 0.0 |
|
residual = projected_x |
|
|
|
all_indices = [] |
|
|
|
should_quantize_dropout = self.training and self.quantize_dropout |
|
|
|
|
|
|
|
rand_quantize_dropout_index = num_quant |
|
|
|
if should_quantize_dropout: |
|
if not exists(rand_quantize_dropout_fixed_seed): |
|
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device) |
|
|
|
rand = random.Random(rand_quantize_dropout_fixed_seed) |
|
|
|
valid_cutoff = max(0, self.quantize_dropout_cutoff_index) |
|
rand_quantize_dropout_index = rand.randrange(valid_cutoff, num_quant) |
|
|
|
if quant_dropout_multiple_of != 1: |
|
rand_quantize_dropout_index = ( |
|
round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1 |
|
) |
|
|
|
rand_quantize_dropout_index = min(rand_quantize_dropout_index, num_quant - 1) |
|
|
|
|
|
null_indices_shape = list(x.shape[:-1]) |
|
null_indices = torch.full(null_indices_shape, -1, device=device, dtype=torch.long) |
|
|
|
|
|
|
|
|
|
|
|
|
|
for quantizer_index, layer in enumerate(self.layers): |
|
|
|
|
|
if quantizer_index > rand_quantize_dropout_index: |
|
|
|
|
|
|
|
all_indices.append(null_indices) |
|
continue |
|
|
|
|
|
|
|
|
|
quantized, indices = layer(residual) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
quantized_detached = quantized.detach() |
|
residual = residual - quantized_detached |
|
quantized_out = quantized_out + quantized |
|
|
|
|
|
if indices is None: |
|
raise ValueError(f"FSQ layer {quantizer_index} did not return indices.") |
|
all_indices.append(indices) |
|
|
|
|
|
final_quantized_out = self.project_out(quantized_out) |
|
|
|
|
|
all_indices = torch.stack(all_indices, dim=-1) |
|
|
|
|
|
if x.ndim > 2: |
|
final_quantized_out = unpack(final_quantized_out, ps, "b * d")[0] |
|
all_indices = unpack(all_indices, ps, "b * q")[0] |
|
|
|
|
|
if self.channel_first and final_quantized_out.ndim > 2: |
|
final_quantized_out = rearrange(final_quantized_out, "b ... d -> b d ...") |
|
|
|
|
|
|
|
|
|
|
|
ret = (final_quantized_out, all_indices) |
|
|
|
if not return_all_codes: |
|
return ret |
|
|
|
|
|
|
|
all_codes = self.get_codes_from_indices(all_indices) |
|
|
|
|
|
|
|
if self.channel_first and all_codes.ndim > 3: |
|
all_codes = rearrange(all_codes, "q b ... d -> q b d ...") |
|
|
|
return (*ret, all_codes) |
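
# Behavior sketch: quantizer i quantizes the residual left by quantizers 0..i-1 and the
# quantized outputs are summed, so `all_indices` carries one index per quantizer along the
# last dimension (shape (..., num_quantizers)); entries of -1 mark quantizers skipped by
# quantizer dropout during training and decode to zero in get_codes_from_indices.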
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Res2Conv1dReluBn(nn.Module): |
|
""" |
|
in_channels == out_channels == channels |
|
""" |
|
|
|
def __init__( |
|
self, |
|
channels, |
|
kernel_size=1, |
|
stride=1, |
|
padding=0, |
|
dilation=1, |
|
bias=True, |
|
scale=4, |
|
): |
|
super().__init__() |
|
assert channels % scale == 0, "{} % {} != 0".format(channels, scale) |
|
self.scale = scale |
|
self.width = channels // scale |
|
self.nums = scale if scale == 1 else scale - 1 |
|
|
|
self.convs = [] |
|
self.bns = [] |
|
for i in range(self.nums): |
|
self.convs.append( |
|
nn.Conv1d( |
|
self.width, |
|
self.width, |
|
kernel_size, |
|
stride, |
|
padding, |
|
dilation, |
|
bias=bias, |
|
) |
|
) |
|
self.bns.append(nn.BatchNorm1d(self.width)) |
|
self.convs = nn.ModuleList(self.convs) |
|
self.bns = nn.ModuleList(self.bns) |
|
|
|
def forward(self, x): |
|
out = [] |
|
spx = torch.split(x, self.width, 1) |
|
sp = spx[0] |
|
|
|
for i, (conv, bn) in enumerate(zip(self.convs, self.bns)): |
|
|
|
if i >= 1: |
|
sp = sp + spx[i] |
|
sp = conv(sp) |
|
sp = bn(F.relu(sp)) |
|
out.append(sp) |
|
if self.scale != 1: |
|
|
|
out.append(spx[self.nums]) |
|
out = torch.cat(out, dim=1) |
|
|
|
return out |
|
|
|
|
|
""" Conv1d + BatchNorm1d + ReLU """ |
|
class Conv1dReluBn(nn.Module): |
|
def __init__( |
|
self, |
|
in_channels, |
|
out_channels, |
|
kernel_size=1, |
|
stride=1, |
|
padding=0, |
|
dilation=1, |
|
bias=True, |
|
): |
|
super().__init__() |
|
self.conv = nn.Conv1d( |
|
in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias |
|
) |
|
self.bn = nn.BatchNorm1d(out_channels) |
|
|
|
def forward(self, x): |
|
|
|
|
|
|
|
return self.bn(F.relu(self.conv(x))) |
|
|
|
|
|
""" The SE connection of 1D case. """ |
|
class SE_Connect(nn.Module): |
|
def __init__(self, channels, se_bottleneck_dim=128): |
|
super().__init__() |
|
self.linear1 = nn.Linear(channels, se_bottleneck_dim) |
|
self.linear2 = nn.Linear(se_bottleneck_dim, channels) |
|
|
|
def forward(self, x): |
|
|
|
out = x.mean(dim=2) |
|
out = F.relu(self.linear1(out)) |
|
out = torch.sigmoid(self.linear2(out)) |
|
out = x * out.unsqueeze(2) |
|
|
|
return out |
|
|
|
|
|
""" SE-Res2Block of the ECAPA-TDNN architecture. """ |
|
class SE_Res2Block(nn.Module): |
|
def __init__(self, channels, kernel_size, stride, padding, dilation, scale): |
|
super().__init__() |
|
self.se_res2block = nn.Sequential( |
|
Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0), |
|
Res2Conv1dReluBn( |
|
channels, kernel_size, stride, padding, dilation, scale=scale |
|
), |
|
Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0), |
|
SE_Connect(channels), |
|
) |
|
|
|
def forward(self, x): |
|
return x + self.se_res2block(x) |
|
|
|
|
|
class ECAPA_TDNN(nn.Module): |
|
def __init__( |
|
self, |
|
channels=512, |
|
feat_dim=80, |
|
embed_dim=192, |
|
pooling_func="ASTP", |
|
global_context_att=False, |
|
emb_bn=False, |
|
): |
|
super().__init__() |
|
|
|
self.layer1 = Conv1dReluBn(feat_dim, channels, kernel_size=5, padding=2) |
|
self.layer2 = SE_Res2Block( |
|
channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8 |
|
) |
|
self.layer3 = SE_Res2Block( |
|
channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8 |
|
) |
|
self.layer4 = SE_Res2Block( |
|
channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8 |
|
) |
|
|
|
cat_channels = channels * 3 |
|
|
|
|
|
self.conv = nn.Conv1d(cat_channels, cat_channels, kernel_size=1) |
|
|
|
|
|
if pooling_func == "TAP": pooling_layer = TAP |
|
elif pooling_func == "TSDP": pooling_layer = TSDP |
|
elif pooling_func == "TSTP": pooling_layer = TSTP |
|
elif pooling_func == "ASTP": pooling_layer = ASTP |
|
elif pooling_func == "MHASTP": pooling_layer = MHASTP |
|
elif pooling_func == "MQMHASTP": pooling_layer = MQMHASTP |
|
else: raise ValueError(f"Unsupported pooling function: {pooling_func}") |
|
|
|
self.pool = pooling_layer( |
|
in_dim=cat_channels, |
|
global_context_att=global_context_att |
|
|
|
) |
|
|
|
|
|
|
|
if hasattr(self.pool, 'get_out_dim'): |
|
self.pool_out_dim = self.pool.get_out_dim() |
|
elif isinstance(self.pool, (TSTP, ASTP, MHASTP, MQMHASTP)): |
|
|
|
self.pool_out_dim = cat_channels * (2 * getattr(self.pool, 'query_num', 1) if isinstance(self.pool, MQMHASTP) else 2) |
|
else: |
|
self.pool_out_dim = cat_channels |
|
|
|
self.bn = nn.BatchNorm1d(self.pool_out_dim) |
|
self.linear = nn.Linear(self.pool_out_dim, embed_dim) |
|
self.emb_bn = emb_bn |
|
if emb_bn: |
|
self.bn2 = nn.BatchNorm1d(embed_dim) |
|
else: |
|
self.bn2 = nn.Identity() |
|
|
|
def forward(self, x, return_latent=False): |
|
|
|
x = x.permute(0, 2, 1) |
|
|
|
out1 = self.layer1(x) |
|
out2 = self.layer2(out1) |
|
out3 = self.layer3(out2) |
|
out4 = self.layer4(out3) |
|
|
|
|
|
out = torch.cat([out2, out3, out4], dim=1) |
|
latent = F.relu(self.conv(out)) |
|
|
|
|
|
pooled_out = self.pool(latent) |
|
bn_out = self.bn(pooled_out) |
|
embedding = self.linear(bn_out) |
|
|
|
if self.emb_bn: |
|
embedding = self.bn2(embedding) |
|
|
|
if return_latent: |
|
|
|
return embedding, latent |
|
return embedding |
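
# Usage sketch (illustrative feature/embedding sizes): the forward expects time-major
# features (B, T, feat_dim) and permutes them internally, e.g.
#
#     spk = ECAPA_TDNN_GLOB_c512(feat_dim=80, embed_dim=192)   # factory defined below
#     emb = spk(torch.randn(4, 200, 80))                       # -> (4, 192)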
|
|
|
|
|
|
|
def ECAPA_TDNN_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): |
|
return ECAPA_TDNN( |
|
channels=1024, |
|
feat_dim=feat_dim, |
|
embed_dim=embed_dim, |
|
pooling_func=pooling_func, |
|
emb_bn=emb_bn, |
|
) |
|
|
|
def ECAPA_TDNN_GLOB_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): |
|
return ECAPA_TDNN( |
|
channels=1024, |
|
feat_dim=feat_dim, |
|
embed_dim=embed_dim, |
|
pooling_func=pooling_func, |
|
global_context_att=True, |
|
emb_bn=emb_bn, |
|
) |
|
|
|
def ECAPA_TDNN_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): |
|
return ECAPA_TDNN( |
|
channels=512, |
|
feat_dim=feat_dim, |
|
embed_dim=embed_dim, |
|
pooling_func=pooling_func, |
|
emb_bn=emb_bn, |
|
) |
|
|
|
def ECAPA_TDNN_GLOB_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): |
|
return ECAPA_TDNN( |
|
channels=512, |
|
feat_dim=feat_dim, |
|
embed_dim=embed_dim, |
|
pooling_func=pooling_func, |
|
global_context_att=True, |
|
emb_bn=emb_bn, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def exists(val): |
|
return val is not None |
|
|
|
def once(fn): |
|
called = False |
|
@wraps(fn) |
|
def inner(x): |
|
nonlocal called |
|
if called: return |
|
called = True |
|
return fn(x) |
|
return inner |
|
|
|
print_once = once(print) |
|
|
|
class Attend(nn.Module): |
|
def __init__(self, dropout=0.0, causal=False, use_flash=False): |
|
super().__init__() |
|
self.dropout = dropout |
|
self.attn_dropout = nn.Dropout(dropout) |
|
|
|
self.causal = causal |
|
self.register_buffer("mask", None, persistent=False) |
|
|
|
self.use_flash = use_flash |
|
can_use_flash = hasattr(F, 'scaled_dot_product_attention') and use_flash |
|
if can_use_flash: |
|
print_once("Using Flash Attention for Perceiver.") |
|
else: |
|
if use_flash: print_once("Flash Attention requested but not available/enabled.") |
|
self.use_flash = False |
|
|
|
|
|
self.efficient_config = namedtuple("EfficientAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]) |
|
|
|
self.cpu_config = self.efficient_config(True, True, True) |
|
self.cuda_config = self.efficient_config(True, True, True) |
|
|
|
|
|
def get_mask(self, n, device): |
|
if exists(self.mask) and self.mask.shape[-1] >= n and self.mask.device == device: |
|
return self.mask[:n, :n] |
|
|
|
mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1) |
|
self.register_buffer("mask", mask, persistent=False) |
|
return mask |
|
|
|
def flash_attn(self, q, k, v, mask=None): |
|
_, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda |
|
|
|
|
|
if k.ndim == 3: |
|
|
|
pass |
|
if v.ndim == 3: |
|
pass |
|
|
|
|
|
flash_mask = None |
|
if exists(mask): |
|
|
|
|
|
|
|
            # torch's scaled_dot_product_attention expects a boolean attn_mask in which
            # True marks positions that may be attended to, the same convention as the
            # masked_fill(~mask, ...) used in the non-flash path below, so the mask is
            # broadcast to (b, h, i, j) without inversion.
            if mask.ndim == 2:
                flash_mask = rearrange(mask, "b j -> b 1 1 j")
                flash_mask = flash_mask.expand(-1, heads, q_len, -1)
            elif mask.ndim == 4 and mask.shape[1] == 1:
                flash_mask = mask.expand(-1, heads, q_len, -1)
            else:
                flash_mask = mask
|
|
|
|
|
|
|
out = F.scaled_dot_product_attention( |
|
q, |
|
k, |
|
v, |
|
attn_mask=flash_mask if exists(flash_mask) else None, |
|
dropout_p=self.dropout if self.training else 0.0, |
|
is_causal=self.causal |
|
) |
|
return out |
|
|
|
def forward(self, q, k, v, mask=None): |
|
""" |
|
einstein notation |
|
b - batch |
|
h - heads |
|
n, i, j - sequence length (query, key/value) |
|
d - feature dimension (d_head) |
|
""" |
|
n, device = q.shape[-2], q.device |
|
scale = q.shape[-1] ** -0.5 |
|
|
|
if self.use_flash: |
|
return self.flash_attn(q, k, v, mask=mask) |
|
|
|
|
|
kv_einsum_eq = "b h j d" |
|
|
|
|
|
sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale |
|
|
|
|
|
        # Precompute the fill value so the causal branch below can use it even when
        # no padding mask is provided.
        mask_value = -torch.finfo(sim.dtype).max

        if exists(mask):

            mask = rearrange(mask, "b j -> b 1 1 j")

            sim = sim.masked_fill(~mask, mask_value)
|
|
|
|
|
if self.causal: |
|
causal_mask = self.get_mask(n, device) |
|
sim = sim.masked_fill(causal_mask, mask_value) |
|
|
|
|
|
attn = sim.softmax(dim=-1) |
|
attn = self.attn_dropout(attn) |
|
|
|
|
|
out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v) |
|
|
|
return out |
|
|
|
|
|
|
|
def Sequential(*mods): |
|
return nn.Sequential(*filter(exists, mods)) |
|
|
|
class RMSNorm(nn.Module): |
|
def __init__(self, dim, scale=True, dim_cond=None): |
|
super().__init__() |
|
self.cond = exists(dim_cond) |
|
|
|
|
|
|
|
self.scale = dim**0.5 |
|
self.gamma = nn.Parameter(torch.ones(dim)) if scale else None |
|
|
|
def forward(self, x, cond=None): |
|
gamma = default(self.gamma, torch.tensor(1.0, device=x.device)) |
|
|
|
normed_x = F.normalize(x, dim=-1) |
|
return normed_x * self.scale * gamma |
|
|
|
|
|
class CausalConv1d(nn.Conv1d): |
|
def __init__(self, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
kernel_size = self.kernel_size[0] |
|
dilation = self.dilation[0] |
|
stride = self.stride[0] |
|
assert stride == 1 |
|
self.causal_padding = dilation * (kernel_size - 1) |
|
|
|
def forward(self, x): |
|
|
|
causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0) |
|
return super().forward(causal_padded_x) |
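
# Quick illustrative check (not called in this module): the left-only padding above keeps
# the sequence length unchanged, so a kernel-size-3 causal conv maps (b, c, n) -> (b, c, n).
def _demo_causal_conv1d():
    conv = CausalConv1d(16, 16, kernel_size=3)
    x = torch.randn(2, 16, 50)
    return conv(x).shape                         # torch.Size([2, 16, 50])
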
|
|
|
class GEGLU(nn.Module): |
|
def forward(self, x): |
|
x, gate = x.chunk(2, dim=-1) |
|
return F.gelu(gate) * x |
|
|
|
def FeedForward(dim, mult=4, causal_conv=False): |
|
dim_inner = int(dim * mult * 2 / 3) |
|
|
|
conv = None |
|
if causal_conv: |
|
conv = nn.Sequential( |
|
Rearrange("b n d -> b d n"), |
|
CausalConv1d(dim_inner, dim_inner, 3), |
|
Rearrange("b d n -> b n d"), |
|
) |
|
|
|
return Sequential( |
|
nn.Linear(dim, dim_inner * 2, bias=False), |
|
GEGLU(), |
|
conv, |
|
nn.Linear(dim_inner, dim, bias=False) |
|
) |
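
# Illustrative sketch of the GEGLU feed-forward factory above (not called in this module).
# dim_inner = int(dim * mult * 2 / 3), the usual GEGLU sizing that keeps the parameter
# count close to a plain mult-times expansion MLP.
def _demo_feedforward():
    ff = FeedForward(dim=128, mult=4, causal_conv=False)
    x = torch.randn(2, 32, 128)
    return ff(x).shape                           # torch.Size([2, 32, 128])
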
|
|
|
|
|
class Attention(nn.Module): |
|
def __init__( |
|
self, |
|
dim, |
|
*, |
|
dim_context=None, |
|
causal=False, |
|
dim_head=64, |
|
heads=8, |
|
dropout=0.0, |
|
use_flash=False, |
|
cross_attn_include_queries=False, |
|
): |
|
super().__init__() |
|
|
|
self.heads = heads |
|
self.cross_attn_include_queries = cross_attn_include_queries |
|
|
|
dim_inner = dim_head * heads |
|
dim_context = default(dim_context, dim) |
|
|
|
self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash) |
|
self.to_q = nn.Linear(dim, dim_inner, bias=False) |
|
|
|
self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False) |
|
self.to_out = nn.Linear(dim_inner, dim, bias=False) |
|
|
|
def forward(self, x, context=None, mask=None): |
|
h, has_context = self.heads, exists(context) |
|
|
|
|
|
|
|
context = default(context, x) |
|
|
|
if has_context and self.cross_attn_include_queries: |
|
|
|
context = torch.cat((x, context), dim=-2) |
|
|
|
|
|
q = self.to_q(x) |
|
k, v = self.to_kv(context).chunk(2, dim=-1) |
|
|
|
|
|
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) |
|
|
|
|
|
out = self.attend(q, k, v, mask=mask) |
|
|
|
|
|
out = rearrange(out, "b h n d -> b n (h d)") |
|
return self.to_out(out) |
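
# Illustrative cross-attention sketch for Attention (not called in this module): a small
# set of query latents attends over a longer context sequence, as PerceiverResampler below
# does. The 1536-dim context and the sequence lengths are assumptions.
def _demo_cross_attention():
    attn = Attention(dim=128, dim_context=1536, dim_head=64, heads=8, use_flash=False)
    latents = torch.randn(2, 32, 128)            # queries
    context = torch.randn(2, 97, 1536)           # keys/values source
    mask = torch.ones(2, 97, dtype=torch.bool)
    out = attn(latents, context=context, mask=mask)
    return out.shape                             # torch.Size([2, 32, 128])
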
|
|
|
|
|
class PerceiverResampler(nn.Module): |
|
def __init__( |
|
self, |
|
*, |
|
dim, |
|
depth=2, |
|
dim_context=None, |
|
num_latents=32, |
|
dim_head=64, |
|
heads=8, |
|
ff_mult=4, |
|
use_flash_attn=False, |
|
): |
|
super().__init__() |
|
dim_context = default(dim_context, dim) |
|
|
|
|
|
self.proj_context = ( |
|
nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity() |
|
) |
|
|
|
|
|
self.latents = nn.Parameter(torch.randn(num_latents, dim)) |
|
nn.init.normal_(self.latents, std=0.02) |
|
|
|
self.layers = nn.ModuleList([]) |
|
for _ in range(depth): |
|
self.layers.append( |
|
nn.ModuleList( |
|
[ |
|
|
|
Attention( |
|
dim=dim, |
|
dim_context=dim, |
|
dim_head=dim_head, |
|
heads=heads, |
|
use_flash=use_flash_attn, |
|
cross_attn_include_queries=False, |
|
), |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FeedForward(dim=dim, mult=ff_mult), |
|
] |
|
) |
|
) |
|
|
|
|
|
|
|
|
|
        # Every block needs its two RMSNorms (post-attention and post-feed-forward) so the
        # 4-way unpacking in forward() works for any depth, not only for the last block.
        for layer in self.layers:
            layer.insert(1, RMSNorm(dim))
            layer.append(RMSNorm(dim))
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x, mask=None): |
|
|
|
batch = x.shape[0] |
|
|
|
|
|
x = self.proj_context(x) |
|
|
|
|
|
latents = repeat(self.latents, "n d -> b n d", b=batch) |
|
|
|
|
|
|
|
for attn, norm1, ff, norm2 in self.layers: |
|
|
|
latents_attn = attn(latents, x, mask=mask) |
|
latents = norm1(latents_attn + latents) |
|
|
|
|
|
latents_ff = ff(latents) |
|
latents = norm2(latents_ff + latents) |
|
|
|
|
|
return latents |
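
# Illustrative sketch (not called in this module): resampling a variable-length feature
# sequence down to a fixed set of 32 latent tokens. The 1536-dim context matches
# ecapa_channels * 3 used by SpeakerEncoder below; all sizes here are assumptions.
def _demo_perceiver_resampler():
    resampler = PerceiverResampler(
        dim=128, dim_context=1536, num_latents=32, depth=2,
        dim_head=64, heads=8, ff_mult=4, use_flash_attn=False,
    )
    feats = torch.randn(2, 97, 1536)             # (batch, frames, context_dim)
    latents = resampler(feats)
    return latents.shape                         # torch.Size([2, 32, 128])
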
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SpeakerEncoder(nn.Module): |
|
""" |
|
Speaker Encoder using ECAPA-TDNN, Perceiver Resampler, and Residual FSQ. |
|
|
|
Args: |
|
input_dim (int): acoustic feature dimension (e.g., mel bins) |
|
out_dim (int): output dimension of the final d-vector |
|
latent_dim (int): latent dimension for perceiver and quantization |
|
token_num (int): number of latent tokens from perceiver |
|
fsq_levels (List[int]): levels for finite scalar quantization |
|
fsq_num_quantizers (int): number of residual quantizers in FSQ |
|
ecapa_embed_dim (int): embedding dimension from ECAPA-TDNN (before projection) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
input_dim: int = 80, |
|
out_dim: int = 1024, |
|
latent_dim: int = 128, |
|
token_num: int = 32, |
|
fsq_levels: List[int] = [4, 4, 4, 4, 4, 4], |
|
fsq_num_quantizers: int = 1, |
|
|
|
ecapa_channels: int = 512, |
|
ecapa_embed_dim: int = 192, |
|
): |
|
super(SpeakerEncoder, self).__init__() |
|
|
|
|
|
|
|
self.speaker_encoder_base = ECAPA_TDNN_GLOB_c512( |
|
feat_dim=input_dim, |
|
embed_dim=ecapa_embed_dim |
|
) |
|
|
|
ecapa_feature_dim = ecapa_channels * 3 |
|
|
|
|
|
self.perceiver_sampler = PerceiverResampler( |
|
dim=latent_dim, |
|
dim_context=ecapa_feature_dim, |
|
num_latents=token_num, |
|
depth=2, |
|
dim_head=64, heads=8, ff_mult=4, |
|
use_flash_attn=True |
|
) |
|
|
|
|
|
self.quantizer = ResidualFSQ( |
|
levels=fsq_levels, |
|
num_quantizers=fsq_num_quantizers, |
|
dim=latent_dim, |
|
channel_first=False, |
|
quantize_dropout=False, |
|
) |
|
|
|
|
|
self.project = nn.Linear(latent_dim * token_num, out_dim) |
|
|
|
def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor: |
|
"""Reconstruct quantized vectors from indices.""" |
|
|
|
|
|
|
|
|
|
zq = self.quantizer.get_output_from_indices(indices) |
|
|
|
return zq |
|
|
|
def get_indices(self, mels: torch.Tensor) -> torch.Tensor: |
|
"""Get FSQ indices directly from mel spectrograms.""" |
|
|
|
_, features = self.speaker_encoder_base(mels, return_latent=True) |
|
x = self.perceiver_sampler(features.transpose(1, 2)) |
|
_, indices = self.quantizer(x) |
|
return indices |
|
|
|
def forward(self, mels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
|
""" |
|
Args: |
|
mels: (B, T_mel, D_mel) - Mel spectrogram input |
|
|
|
Return: |
|
x_vector: (B, ecapa_embed_dim) - Global speaker embedding from ECAPA |
|
d_vector: (B, out_dim) - Speaker embedding derived from quantized tokens |
|
""" |
|
|
|
x_vector, features = self.speaker_encoder_base(mels, return_latent=True) |
|
|
|
|
|
|
|
|
|
perceiver_latents = self.perceiver_sampler(features.transpose(1, 2)) |
|
|
|
|
|
|
|
|
|
zq, indices = self.quantizer(perceiver_latents) |
|
|
|
|
|
|
|
zq_flat = rearrange(zq, 'b t d -> b (t d)') |
|
d_vector = self.project(zq_flat) |
|
|
|
return x_vector, d_vector |
|
|
|
def tokenize(self, mels: torch.Tensor) -> torch.Tensor: |
|
"""Tokenize the input mel spectrogram to get FSQ indices.""" |
|
|
|
_, features = self.speaker_encoder_base(mels, return_latent=True) |
|
x = self.perceiver_sampler(features.transpose(1, 2)) |
|
_, indices = self.quantizer(x) |
|
return indices |
|
|
|
def detokenize(self, indices: torch.Tensor) -> torch.Tensor: |
|
"""Detokenize FSQ indices to get the final d-vector.""" |
|
|
|
|
|
zq = self.get_codes_from_indices(indices) |
|
|
|
|
|
zq_flat = rearrange(zq, 'b t d -> b (t d)') |
|
d_vector = self.project(zq_flat) |
|
return d_vector |
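
# Illustrative sketch (not called in this module). It relies on the ECAPA_TDNN_GLOB_c512
# and ResidualFSQ components defined/imported elsewhere in this file, so treat it as a
# usage sketch rather than a test; the (batch, frames, mel_bins) input shape is an
# assumption.
def _demo_speaker_encoder():
    enc = SpeakerEncoder(input_dim=80, out_dim=1024, latent_dim=128, token_num=32)
    mels = torch.randn(2, 200, 80)
    x_vector, d_vector = enc(mels)               # (2, ecapa_embed_dim), (2, 1024)
    indices = enc.tokenize(mels)                 # FSQ indices for the 32 latent tokens
    d_vector_rt = enc.detokenize(indices)        # rebuild the d-vector from indices alone
    return x_vector.shape, d_vector.shape, d_vector_rt.shape
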
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ema_inplace(moving_avg, new, decay): |
|
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) |
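
# Tiny illustration of the in-place EMA update above (not used by the model code): with
# decay=0.99 each call moves the buffer only 1% of the way toward the new value.
def _demo_ema_inplace():
    buf = torch.ones(4)
    ema_inplace(buf, torch.zeros(4), decay=0.99)
    return buf                                   # tensor([0.99, 0.99, 0.99, 0.99])
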
|
|
|
|
|
class FactorizedVectorQuantize(nn.Module): |
|
def __init__( |
|
self, |
|
input_dim: int, |
|
codebook_size: int, |
|
codebook_dim: int, |
|
commitment: float, |
|
codebook_loss_weight: float = 1.0, |
|
decay: float = 0.99, |
|
threshold_ema_dead_code: float = 2.0, |
|
momentum: float = 0.99, |
|
use_l2_normlize: bool = True, |
|
**kwargs, |
|
): |
|
super().__init__() |
|
self.input_dim = input_dim |
|
self.codebook_size = codebook_size |
|
self.codebook_dim = codebook_dim |
|
self.commitment = commitment |
|
self.codebook_loss_weight = codebook_loss_weight |
|
self.decay = decay |
|
self.threshold_ema_dead_code = threshold_ema_dead_code |
|
|
|
self.use_l2_normlize = use_l2_normlize |
|
|
|
if input_dim != self.codebook_dim: |
|
self.in_project = WNConv1d(input_dim, self.codebook_dim, kernel_size=1) |
|
self.out_project = WNConv1d(self.codebook_dim, input_dim, kernel_size=1) |
|
else: |
|
self.in_project = nn.Identity() |
|
self.out_project = nn.Identity() |
|
|
|
|
|
self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim) |
|
|
|
|
|
|
|
self.register_buffer("cluster_size", torch.zeros(self.codebook_size)) |
|
|
|
|
|
|
|
def forward(self, z: torch.Tensor) -> Dict[str, Any]: |
|
"""Quantizes the input tensor using a fixed codebook and returns |
|
the corresponding codebook vectors and losses. |
|
|
|
Parameters |
|
---------- |
|
z : Tensor[B x D_in x T] |
|
|
|
Returns |
|
------- |
|
Dict containing: |
|
z_q (Tensor[B x D_in x T]): Quantized continuous representation (passed through out_project) |
|
indices (Tensor[B x T]): Codebook indices |
|
vq_loss (Tensor[1]): Combined VQ loss (codebook + commitment) |
|
perplexity (Tensor[1]): Codebook perplexity metric |
|
active_num (Tensor[1]): Number of active codebook entries |
|
""" |
|
|
|
B, _, T = z.shape |
|
|
|
|
|
z_e = self.in_project(z) |
|
|
|
|
|
z_q, indices, dists = self.decode_latents(z_e) |
|
|
|
|
|
with torch.no_grad(): |
|
embed_onehot = F.one_hot(indices, self.codebook_size).type(z_e.dtype) |
|
|
|
embed_onehot_flat = rearrange(embed_onehot, 'b t c -> (b t) c') |
|
avg_probs = torch.mean(embed_onehot_flat, dim=0) |
|
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) |
|
|
|
|
|
active_num_tensor = (embed_onehot_flat.sum(0) > 0).sum() |
|
if self.training: |
|
|
|
ema_inplace(self.cluster_size, embed_onehot_flat.sum(0), self.decay) |
|
|
|
active_num_tensor = (self.cluster_size > self.threshold_ema_dead_code).sum() |
|
|
|
|
|
|
|
commit_loss = torch.tensor(0.0, device=z.device) |
|
codebook_loss = torch.tensor(0.0, device=z.device) |
|
vq_loss = torch.tensor(0.0, device=z.device) |
|
|
|
if self.training: |
|
|
|
|
|
commit_loss = F.mse_loss(z_e, z_q.detach()) * self.commitment |
|
|
|
|
|
|
|
codebook_loss = F.mse_loss(z_q, z_e.detach()) * self.codebook_loss_weight |
|
|
|
vq_loss = commit_loss + codebook_loss |
|
|
|
|
|
|
|
z_q_st = z_e + (z_q - z_e).detach() |
|
|
|
|
|
z_q_out = self.out_project(z_q_st) |
|
|
|
return { |
|
"z_q": z_q_out, |
|
"indices": indices, |
|
|
|
"vq_loss": vq_loss, |
|
"perplexity": perplexity, |
|
"active_num": active_num_tensor.float(), |
|
} |
|
|
|
def embed_code(self, embed_id): |
|
"""Retrieve codebook vectors for given indices.""" |
|
return F.embedding(embed_id, self.codebook.weight) |
|
|
|
def decode_code(self, embed_id): |
|
"""Retrieve codebook vectors and transpose to (B, D, T) format.""" |
|
|
|
|
|
|
|
return self.embed_code(embed_id).transpose(1, 2) |
|
|
|
def decode_latents(self, latents): |
|
"""Find nearest codebook entries for latent vectors.""" |
|
|
|
B, D_code, T = latents.shape |
|
encodings = rearrange(latents, "b d t -> (b t) d") |
|
codebook = self.codebook.weight |
|
|
|
|
|
if self.use_l2_normlize: |
|
encodings = F.normalize(encodings, p=2, dim=-1) |
|
codebook = F.normalize(codebook, p=2, dim=-1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
dist = ( |
|
encodings.pow(2).sum(1, keepdim=True) |
|
- 2 * (encodings @ codebook.t()) |
|
+ codebook.pow(2).sum(1, keepdim=True).t() |
|
) |
|
|
|
|
|
indices = torch.argmin(dist, dim=-1) |
|
indices = rearrange(indices, "(b t) -> b t", b=B) |
|
|
|
|
|
z_q = self.decode_code(indices) |
|
|
|
return z_q, indices, dist |
|
|
|
|
|
def tokenize(self, z: torch.Tensor) -> torch.Tensor: |
|
"""Tokenize the input tensor without loss calculation.""" |
|
|
|
z_e = self.in_project(z) |
|
_, indices, _ = self.decode_latents(z_e) |
|
return indices |
|
|
|
def detokenize(self, indices: torch.Tensor) -> torch.Tensor: |
|
"""Detokenize indices to quantized vectors in input dimension.""" |
|
|
|
z_q_code_dim = self.decode_code(indices) |
|
z_q_out = self.out_project(z_q_code_dim) |
|
return z_q_out |
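
# Illustrative round-trip sketch (not called in this module): quantize a (B, D, T) feature
# map, then rebuild the quantized tensor from the indices alone. The sizes and hyperparameters
# are assumptions, and the projection layers rely on the WNConv1d helper defined elsewhere
# in this file.
def _demo_factorized_vq():
    vq = FactorizedVectorQuantize(
        input_dim=1024, codebook_size=8192, codebook_dim=8, commitment=0.25,
    )
    vq.eval()                                    # eval mode: no EMA update, zero VQ losses
    z = torch.randn(2, 1024, 50)
    out = vq(z)                                  # dict with z_q, indices, vq_loss, ...
    indices = vq.tokenize(z)                     # (2, 50)
    z_q = vq.detokenize(indices)                 # (2, 1024, 50)
    return out["z_q"].shape, indices.shape, z_q.shape
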
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BiCodec(nn.Module): |
|
""" |
|
BiCodec model for speech synthesis, incorporating a speaker encoder, feature encoder/decoder, |
|
quantizer, and wave generator. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
mel_params: Dict[str, Any], |
|
encoder: nn.Module, |
|
decoder: nn.Module, |
|
quantizer: nn.Module, |
|
speaker_encoder: nn.Module, |
|
prenet: nn.Module, |
|
postnet: nn.Module, |
|
**kwargs |
|
) -> None: |
|
""" |
|
Initializes the BiCodec model with the required components. |
|
|
|
Args: |
|
mel_params (dict): Parameters for the mel-spectrogram transformer. |
|
encoder (nn.Module): Encoder module. |
|
decoder (nn.Module): Decoder module. |
|
quantizer (nn.Module): Quantizer module. |
|
speaker_encoder (nn.Module): Speaker encoder module. |
|
prenet (nn.Module): Prenet network. |
|
postnet (nn.Module): Postnet network. |
|
""" |
|
super().__init__() |
|
self.encoder = encoder |
|
self.decoder = decoder |
|
self.quantizer = quantizer |
|
self.speaker_encoder = speaker_encoder |
|
self.prenet = prenet |
|
self.postnet = postnet |
|
self._init_mel_transformer(mel_params) |
|
|
|
@classmethod |
|
def load_from_config_and_checkpoint(cls, model_dir: Path, config_dict: Dict[str, Any], **kwargs) -> "BiCodec": |
|
"""Loads the model from a config dictionary and checkpoint file.""" |
|
ckpt_path = model_dir / 'model.safetensors' |
|
if not ckpt_path.is_file(): |
|
raise FileNotFoundError(f"BiCodec checkpoint not found at {ckpt_path}") |
|
|
|
audio_tokenizer_config = config_dict |
|
|
|
|
|
mel_params = audio_tokenizer_config.get("mel_params", {}) |
|
encoder_cfg = audio_tokenizer_config.get("encoder", {}) |
|
quantizer_cfg = audio_tokenizer_config.get("quantizer", {}) |
|
prenet_cfg = audio_tokenizer_config.get("prenet", {}) |
|
postnet_cfg = audio_tokenizer_config.get("postnet", {}) |
|
decoder_cfg = audio_tokenizer_config.get("decoder", {}) |
|
speaker_encoder_cfg = audio_tokenizer_config.get("speaker_encoder", {}) |
|
|
|
|
|
required_keys = { |
|
"encoder": ["input_channels", "vocos_dim", "vocos_intermediate_dim", "vocos_num_layers", "out_channels"], |
|
"quantizer": ["input_dim", "codebook_size", "codebook_dim", "commitment"], |
|
"prenet": ["input_channels", "vocos_dim", "vocos_intermediate_dim", "vocos_num_layers", "out_channels"], |
|
"postnet": ["input_channels", "vocos_dim", "vocos_intermediate_dim", "vocos_num_layers", "out_channels"], |
|
"decoder": ["input_channel", "channels", "rates", "kernel_sizes"], |
|
"speaker_encoder": ["input_dim", "out_dim", "latent_dim", "token_num"], |
|
"mel_params": ["sample_rate", "n_fft", "win_length", "hop_length", "num_mels"] |
|
} |
|
for comp, keys in required_keys.items(): |
|
cfg = audio_tokenizer_config.get(comp, {}) |
|
if not cfg: logging.get_logger(__name__).warning(f"BiCodec config missing section: '{comp}'") |
|
for key in keys: |
|
if key not in cfg: |
|
logging.get_logger(__name__).warning(f"BiCodec config missing key '{key}' in section '{comp}'") |
|
|
|
|
|
|
|
|
|
encoder = Encoder(**encoder_cfg) if encoder_cfg else None |
|
quantizer = FactorizedVectorQuantize(**quantizer_cfg) if quantizer_cfg else None |
|
prenet = Decoder(**prenet_cfg) if prenet_cfg else None |
|
postnet = Decoder(**postnet_cfg) if postnet_cfg else None |
|
decoder = WaveGenerator(**decoder_cfg) if decoder_cfg else None |
|
speaker_encoder = SpeakerEncoder(**speaker_encoder_cfg) if speaker_encoder_cfg else None |
|
|
|
|
|
if not all([encoder, quantizer, prenet, postnet, decoder, speaker_encoder, mel_params]): |
|
raise ValueError("Failed to initialize one or more BiCodec components due to missing configuration.") |
|
|
|
|
|
model = cls( |
|
mel_params=mel_params, |
|
encoder=encoder, |
|
decoder=decoder, |
|
quantizer=quantizer, |
|
speaker_encoder=speaker_encoder, |
|
prenet=prenet, |
|
postnet=postnet, |
|
) |
|
|
|
|
|
try: |
|
state_dict = load_file(ckpt_path, device="cpu") |
|
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) |
|
            if missing_keys:
                logging.get_logger(__name__).warning(f"BiCodec missing keys: {missing_keys}")
            if unexpected_keys:
                logging.get_logger(__name__).warning(f"BiCodec unexpected keys: {unexpected_keys}")
|
except Exception as e: |
|
raise IOError(f"Error loading BiCodec state dict from {ckpt_path}: {e}") |
|
|
|
model.eval() |
|
|
|
|
|
return model |
|
|
|
def _init_mel_transformer(self, config: Dict[str, Any]): |
|
|
|
sr = config.get("sample_rate", 16000) |
|
n_fft = config.get("n_fft", 1024) |
|
win_length = config.get("win_length", n_fft) |
|
hop_length = config.get("hop_length", n_fft // 4) |
|
fmin = config.get("mel_fmin", 0) |
|
fmax = config.get("mel_fmax", None) |
|
n_mels = config.get("num_mels", 80) |
|
power = config.get("power", 2.0) |
|
norm = config.get("norm", "slaney") |
|
mel_scale = config.get("mel_scale", "htk") |
|
|
|
self.mel_transformer = TT.MelSpectrogram( |
|
sample_rate=sr, |
|
n_fft=n_fft, |
|
win_length=win_length, |
|
hop_length=hop_length, |
|
f_min=fmin, |
|
f_max=fmax, |
|
n_mels=n_mels, |
|
power=power, |
|
norm=norm, |
|
mel_scale=mel_scale, |
|
).eval() |
|
|
|
def remove_weight_norm(self): |
|
"""Removes weight normalization from components that support it.""" |
|
def _remove_wn(m): |
|
if hasattr(m, 'remove_weight_norm'): |
|
m.remove_weight_norm() |
|
elif isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)): |
|
try: |
|
remove_weight_norm(m) |
|
except ValueError: |
|
pass |
|
|
|
self.apply(_remove_wn) |
|
|
|
|
|
@torch.no_grad() |
|
def tokenize(self, feat: torch.Tensor, ref_wav: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
|
""" Tokenizes input features and reference wav into semantic and global tokens. """ |
|
|
|
device = feat.device |
|
self.mel_transformer.to(device) |
|
self.encoder.to(device) |
|
self.quantizer.to(device) |
|
self.speaker_encoder.to(device) |
|
|
|
|
|
mel = self.mel_transformer(ref_wav) |
|
|
|
|
|
z = self.encoder(feat) |
|
|
|
|
|
semantic_tokens = self.quantizer.tokenize(z) |
|
|
|
|
|
|
|
global_tokens = self.speaker_encoder.tokenize(mel.transpose(1, 2)) |
|
|
|
|
|
|
|
return global_tokens, semantic_tokens |
|
|
|
|
|
@torch.no_grad() |
|
def detokenize(self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor) -> torch.Tensor: |
|
""" Detokenizes semantic and global tokens into a waveform. """ |
|
|
|
device = semantic_tokens.device |
|
self.quantizer.to(device) |
|
self.speaker_encoder.to(device) |
|
self.prenet.to(device) |
|
self.decoder.to(device) |
|
|
|
|
|
|
|
|
|
|
|
z_q = self.quantizer.detokenize(semantic_tokens) |
|
|
|
|
|
|
|
d_vector = self.speaker_encoder.detokenize(global_tokens) |
|
|
|
|
|
|
|
x = self.prenet(z_q, d_vector) |
|
|
|
|
|
|
|
if d_vector.ndim == 2: |
|
d_vector_unsqueezed = d_vector.unsqueeze(-1) |
|
else: |
|
d_vector_unsqueezed = d_vector |
|
|
|
|
|
if x.shape[1] == d_vector_unsqueezed.shape[1]: |
|
|
|
x = x + d_vector_unsqueezed |
|
else: |
|
|
|
logging.get_logger(__name__).warning(f"Prenet output dim {x.shape[1]} != d-vector dim {d_vector_unsqueezed.shape[1]}. Skipping residual connection.") |
|
|
|
|
|
|
|
|
|
wav_recon = self.decoder(x) |
|
|
|
return wav_recon |
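
# Usage sketch for BiCodec (illustrative only; the arguments are placeholders supplied by
# the caller, not fixed paths). In practice SparkTTSModel.from_pretrained below wires this
# up from SparkTTSConfig.bicodec_config["audio_tokenizer"].
def _demo_bicodec_roundtrip(model_dir, config_dict, feat, ref_wav):
    codec = BiCodec.load_from_config_and_checkpoint(Path(model_dir), config_dict)
    global_tokens, semantic_tokens = codec.tokenize(feat, ref_wav)   # speaker + content tokens
    wav_recon = codec.detokenize(global_tokens, semantic_tokens)     # reconstructed waveform
    return wav_recon
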
|
|
|
|
|
|
|
from .configuration_spark_tts import SparkTTSConfig |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
class SparkTTSModel(PreTrainedModel, GenerationMixin): |
|
""" |
|
SparkTTS model integrating LLM, BiCodec, and Wav2Vec2 for text-to-speech. |
|
""" |
|
config_class = SparkTTSConfig |
|
base_model_prefix = "spark_tts" |
|
_supports_load_fast = False |
|
|
|
def __init__(self, config: SparkTTSConfig, llm=None, wav2vec2_model=None, wav2vec2_processor=None, bicodec=None): |
|
super().__init__(config) |
|
self.config = config |
|
self.llm = llm |
|
self.wav2vec2_model = wav2vec2_model |
|
self.wav2vec2_processor = wav2vec2_processor |
|
self.bicodec = bicodec |
|
|
|
|
|
if self.wav2vec2_model and hasattr(self.wav2vec2_model.config, 'output_hidden_states'): |
|
self.wav2vec2_model.config.output_hidden_states = True |
|
|
|
|
|
@classmethod |
|
def from_pretrained( |
|
cls, |
|
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], |
|
*model_args, |
|
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, |
|
cache_dir: Optional[Union[str, os.PathLike]] = None, |
|
ignore_mismatched_sizes: bool = False, |
|
force_download: bool = False, |
|
local_files_only: bool = False, |
|
token: Optional[Union[str, bool]] = None, |
|
revision: str = "main", |
|
        use_safetensors: Optional[bool] = None,
|
**kwargs, |
|
): |
|
|
|
if config is None: |
|
config, model_kwargs = cls.config_class.from_pretrained( |
|
pretrained_model_name_or_path, |
|
*model_args, |
|
cache_dir=cache_dir, |
|
force_download=force_download, |
|
local_files_only=local_files_only, |
|
token=token, |
|
revision=revision, |
|
return_unused_kwargs=True, |
|
**kwargs, |
|
) |
|
else: |
|
model_kwargs = kwargs |
|
|
|
|
|
device_map = model_kwargs.pop("device_map", None) |
|
torch_dtype = model_kwargs.pop("torch_dtype", "auto") |
|
|
|
|
|
trust_remote_code = model_kwargs.pop("trust_remote_code", False) |
|
|
|
|
|
if pretrained_model_name_or_path is not None: |
|
resolved_model_path = Path(pretrained_model_name_or_path) |
|
if not resolved_model_path.is_dir(): |
|
|
|
|
|
                try:
                    # cached_file is re-exported by transformers.utils; imported locally in case
                    # it is not already pulled in at the top of this module.
                    from transformers.utils import cached_file
                    resolved_model_path = Path(cached_file(
|
pretrained_model_name_or_path, |
|
filename=cls.config_class.config_files[0], |
|
cache_dir=cache_dir, |
|
force_download=force_download, |
|
local_files_only=local_files_only, |
|
token=token, |
|
revision=revision, |
|
_raise_exceptions_for_missing_entries=False, |
|
_raise_exceptions_for_connection_errors=False, |
|
)).parent |
|
except Exception as e: |
|
logger.warning(f"Could not resolve cache path for {pretrained_model_name_or_path}: {e}. Assuming it's a local path.") |
|
resolved_model_path = Path(pretrained_model_name_or_path) |
|
if not resolved_model_path.is_dir(): |
|
raise EnvironmentError(f"Cannot find model directory at {resolved_model_path}") |
|
else: |
|
raise ValueError("pretrained_model_name_or_path must be provided.") |
|
|
|
|
|
|
|
def _resolve_path(sub_path): |
|
p = Path(sub_path) |
|
if p.is_absolute(): |
|
return str(p) |
|
else: |
|
|
|
return str(resolved_model_path / p) |
|
|
|
|
|
llm_path = _resolve_path(config.llm_model_name_or_path) |
|
logger.info(f"Loading LLM from resolved path: {llm_path}") |
|
try: |
|
llm = AutoModelForCausalLM.from_pretrained( |
|
llm_path, |
|
torch_dtype=torch_dtype if torch_dtype != "auto" else config.torch_dtype, |
|
trust_remote_code=trust_remote_code, |
|
**model_kwargs |
|
) |
|
except Exception as e: |
|
raise OSError(f"Failed to load LLM from {llm_path}: {e}") |
|
|
|
|
|
w2v_path = _resolve_path(config.wav2vec2_model_name_or_path) |
|
logger.info(f"Loading Wav2Vec2 from resolved path: {w2v_path}") |
|
try: |
|
wav2vec2_processor = Wav2Vec2FeatureExtractor.from_pretrained(w2v_path, trust_remote_code=trust_remote_code) |
|
wav2vec2_model = Wav2Vec2Model.from_pretrained(w2v_path, trust_remote_code=trust_remote_code) |
|
except Exception as e: |
|
raise OSError(f"Failed to load Wav2Vec2 components from {w2v_path}: {e}") |
|
|
|
|
|
bicodec_path = _resolve_path(config.bicodec_model_name_or_path) |
|
logger.info(f"Loading BiCodec from resolved path: {bicodec_path}") |
|
|
|
if not config.bicodec_config or "audio_tokenizer" not in config.bicodec_config: |
|
raise ValueError("BiCodec configuration ('bicodec_config' with 'audio_tokenizer' key) not found in SparkTTSConfig.") |
|
try: |
|
|
|
bicodec = BiCodec.load_from_config_and_checkpoint( |
|
model_dir=Path(bicodec_path), |
|
config_dict=config.bicodec_config["audio_tokenizer"] |
|
) |
|
except Exception as e: |
|
raise OSError(f"Failed to load BiCodec from {bicodec_path}: {e}") |
|
|
|
|
|
|
|
model = cls(config, llm=llm, wav2vec2_model=wav2vec2_model, wav2vec2_processor=wav2vec2_processor, bicodec=bicodec) |
|
|
|
|
|
|
|
|
|
if torch.cuda.is_available(): |
|
current_device = torch.cuda.current_device() |
|
device = torch.device(f"cuda:{current_device}") |
|
else: |
|
device = torch.device("cpu") |
|
logger.info(f"Placing SparkTTSModel and components on device: {device}") |
|
model.to(device) |
|
|
|
return model |
|
|
|
|
|
|
|
def get_input_embeddings(self): |
|
if self.llm: |
|
return self.llm.get_input_embeddings() |
|
return None |
|
|
|
def set_input_embeddings(self, value): |
|
if self.llm: |
|
self.llm.set_input_embeddings(value) |
|
else: |
|
logger.warning("LLM not loaded, cannot set input embeddings.") |
|
|
|
def get_output_embeddings(self): |
|
if self.llm: |
|
|
|
return self.llm.get_output_embeddings() |
|
return None |
|
|
|
def set_output_embeddings(self, new_embeddings): |
|
if self.llm and hasattr(self.llm, 'set_output_embeddings'): |
|
self.llm.set_output_embeddings(new_embeddings) |
|
else: |
|
logger.warning("LLM not loaded or does not support set_output_embeddings.") |
|
|
|
|
|
|
|
|
|
def post_init(self): |
|
|
|
if self.wav2vec2_model and hasattr(self.wav2vec2_model.config, 'output_hidden_states'): |
|
if not self.wav2vec2_model.config.output_hidden_states: |
|
self.wav2vec2_model.config.output_hidden_states = True |
|
logger.info("Set wav2vec2_model.config.output_hidden_states=True") |
|
|
|
@property |
|
def device(self) -> torch.device: |
|
""" Override device property to report the LLM's device as representative """ |
|
if self.llm: |
|
return self.llm.device |
|
else: |
|
|
|
|
|
try: |
|
return next(self.parameters()).device |
|
except StopIteration: |
|
|
|
return torch.device("cpu") |
|
|
|
@torch.no_grad() |
|
def _extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor: |
|
"""Extract wav2vec2 features. Input wavs: (B, T_wav)""" |
|
if not self.wav2vec2_model or not self.wav2vec2_processor: |
|
raise RuntimeError("Wav2Vec2 components not loaded.") |
|
|
|
|
|
target_device = self.wav2vec2_model.device |
|
wavs_on_device = wavs.to(target_device) |
|
|
|
|
|
processor_output = self.wav2vec2_processor( |
|
wavs_on_device, |
|
sampling_rate=self.config.sample_rate, |
|
return_tensors="pt", |
|
padding=True, |
|
) |
|
inputs = processor_output.input_values |
|
|
|
print(f"Shape returned by processor: {inputs.shape}") |
|
|
|
|
|
if inputs.ndim == 4 and inputs.shape[1] == 1 and inputs.shape[2] == 1: |
|
print(f"Reshaping input from {inputs.shape} to 2D.") |
|
inputs = inputs.squeeze(1).squeeze(1) |
|
elif inputs.ndim == 3 and inputs.shape[1] == 1: |
|
print(f"Reshaping input from {inputs.shape} to 2D.") |
|
inputs = inputs.squeeze(1) |
|
|
|
|
|
if inputs.ndim != 2: |
|
raise ValueError(f"Unexpected shape after processing/reshaping: {inputs.shape}. Expected 2D input for Wav2Vec2Model.") |
|
|
|
print(f"Shape BEFORE Wav2Vec2Model: {inputs.shape}") |
|
|
|
|
|
inputs = inputs.to(target_device) |
|
|
|
outputs = self.wav2vec2_model(inputs, output_hidden_states=True) |
|
|
|
if outputs.hidden_states is None: |
|
raise ValueError("Wav2Vec2 model did not return hidden states. Ensure config.output_hidden_states=True.") |
|
|
|
|
|
num_layers = len(outputs.hidden_states) |
|
indices_to_mix = [11, 14, 16] |
|
valid_indices = [i for i in indices_to_mix if i < num_layers] |
|
|
|
if len(valid_indices) != len(indices_to_mix): |
|
logger.warning(f"Requested Wav2Vec2 hidden state indices {indices_to_mix} out of range (0-{num_layers-1}). Using available valid indices: {valid_indices}.") |
|
if not valid_indices: |
|
logger.warning("No valid hidden state indices for mixing. Using last hidden state.") |
|
feats_mix = outputs.last_hidden_state |
|
else: |
|
|
|
feats_mix = torch.stack([outputs.hidden_states[i] for i in valid_indices]).mean(dim=0) |
|
else: |
|
|
|
feats_mix = (outputs.hidden_states[11] + outputs.hidden_states[14] + outputs.hidden_states[16]) / 3 |
|
|
|
|
|
return feats_mix.transpose(1, 2) |
|
|
|
def _get_ref_clip(self, wav: np.ndarray) -> np.ndarray: |
|
"""Get reference audio clip for speaker embedding.""" |
|
ref_samples = int(self.config.sample_rate * self.config.ref_segment_duration) |
|
latent_hop_length = self.config.latent_hop_length |
|
|
|
ref_segment_length = max(latent_hop_length, (ref_samples // latent_hop_length) * latent_hop_length) |
|
|
|
wav_length = len(wav) |
|
|
|
if wav_length == 0: |
|
return np.zeros(ref_segment_length, dtype=np.float32) |
|
if ref_segment_length > wav_length: |
|
num_repeats = (ref_segment_length // wav_length) + 1 |
|
wav = np.tile(wav, num_repeats) |
|
|
|
return wav[:ref_segment_length].astype(np.float32) |
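
    # Worked example of the clip-length arithmetic above (the numbers are assumptions, not
    # necessarily the shipped config): with sample_rate=16000, ref_segment_duration=6.0 and
    # latent_hop_length=320, ref_samples = 96000 and ref_segment_length = (96000 // 320) * 320
    # = 96000; a 2 s clip (32000 samples) is tiled 4x to 128000 samples, then truncated to 96000.
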
|
|
|
|
|
@torch.no_grad() |
|
def _tokenize_audio(self, audio_path: str) -> Tuple[torch.Tensor, torch.Tensor]: |
|
"""Load audio, extract features, and tokenize using BiCodec.""" |
|
wav_np = load_audio( |
|
audio_path, |
|
sampling_rate=self.config.sample_rate, |
|
volume_normalize=self.config.volume_normalize, |
|
) |
|
wav_ref_np = self._get_ref_clip(wav_np) |
|
|
|
|
|
wav = torch.from_numpy(wav_np).unsqueeze(0).float().to(self.device) |
|
ref_wav = torch.from_numpy(wav_ref_np).unsqueeze(0).float().to(self.device) |
|
|
|
|
|
feat = self._extract_wav2vec2_features(wav) |
|
|
|
|
|
|
|
global_tokens, semantic_tokens = self.bicodec.tokenize(feat, ref_wav) |
|
|
|
|
|
return global_tokens, semantic_tokens |
|
|
|
@torch.no_grad() |
|
def _detokenize_audio(self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor) -> np.ndarray: |
|
"""Detokenize using BiCodec to get waveform.""" |
|
global_tokens = global_tokens.to(self.device) |
|
semantic_tokens = semantic_tokens.to(self.device) |
|
self.bicodec.to(self.device) |
|
|
|
|
|
wav_rec = self.bicodec.detokenize(global_tokens, semantic_tokens) |
|
|
|
return wav_rec.detach().squeeze(0).squeeze(0).cpu().numpy() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_inputs_for_generation(self, input_ids, **kwargs): |
|
""" Prepares inputs for the LLM's generate method. """ |
|
if not self.llm: |
|
raise RuntimeError("LLM component not loaded.") |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
model_inputs = self.llm.prepare_inputs_for_generation(input_ids, **kwargs) |
|
return model_inputs |
|
except AttributeError: |
|
|
|
logger.warning("LLM does not have 'prepare_inputs_for_generation'. Using basic fallback.") |
|
model_kwargs = {} |
|
model_kwargs["past_key_values"] = kwargs.get("past_key_values", None) |
|
model_kwargs["use_cache"] = kwargs.get("use_cache", None) |
|
|
|
if "attention_mask" in kwargs: |
|
model_kwargs["attention_mask"] = kwargs["attention_mask"] |
|
return {"input_ids": input_ids, **model_kwargs} |
|
|
|
|
|
|
|
|
|
|
|
def forward( |
|
self, |
|
input_ids: Optional[torch.Tensor] = None, |
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.Tensor] = None, |
|
**kwargs |
|
) -> Any: |
|
""" |
|
Minimal forward pass that delegates to the underlying LLM. |
|
Required for compatibility with GenerationMixin. |
|
Accepts arguments typically returned by prepare_inputs_for_generation. |
|
""" |
|
if not self.llm: |
|
raise RuntimeError("LLM component not loaded.") |
|
|
|
|
|
|
|
llm_kwargs = { |
|
"past_key_values": past_key_values, |
|
"attention_mask": attention_mask, |
|
**kwargs |
|
} |
|
|
|
|
|
|
|
|
|
if position_ids is not None: |
|
llm_kwargs["position_ids"] = position_ids |
|
|
|
return self.llm(input_ids=input_ids, **llm_kwargs) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|