# Copyright (c) 2025 SparkAudio
# 2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Heavily based on https://github.com/lucidrains/vector-quantize-pytorch

from typing import Any, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def ema_inplace(moving_avg, new, decay):
    """In-place exponential moving average update of `moving_avg` toward `new`."""
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
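
# Worked example of the EMA update above: with decay = 0.99, moving_avg = 10.0 and
# new = 0.0, the update gives 10.0 * 0.99 + 0.0 * 0.01 = 9.9. An unused code's
# `cluster_size` entry therefore decays geometrically and eventually drops below
# `threshold_ema_dead_code`, after which FactorizedVectorQuantize no longer counts
# it as active.
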
class FactorizedVectorQuantize(nn.Module):
    """Vector quantizer with a factorized (low-dimensional) codebook and EMA code-usage tracking."""

    def __init__(
self,
input_dim: int,
codebook_size: int,
codebook_dim: int,
commitment: float,
codebook_loss_weight: float = 1.0,
decay: float = 0.99,
threshold_ema_dead_code: float = 2,
momentum: float = 0.99,
**kwargs,
):
super().__init__()
self.input_dim = input_dim
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.commitment = commitment
self.codebook_loss_weight = codebook_loss_weight
self.decay = decay
self.threshold_ema_dead_code = threshold_ema_dead_code
self.momentum = momentum
if input_dim != self.codebook_dim:
self.in_project = WNConv1d(input_dim, self.codebook_dim, kernel_size=1)
self.out_project = WNConv1d(self.codebook_dim, input_dim, kernel_size=1)
else:
self.in_project = nn.Identity()
self.out_project = nn.Identity()
self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
        self.register_buffer("cluster_size", torch.zeros(self.codebook_size))

    def forward(self, z: torch.Tensor) -> Dict[str, Any]:
        """Quantizes the input tensor with the learned codebook and returns the
        corresponding codebook vectors.

        Parameters
        ----------
        z : Tensor[B x D x T]
            Continuous input representation.

        Returns
        -------
        Dict[str, Any] with the following keys:
            "z_q" : Tensor[B x D x T]
                Quantized continuous representation of the input.
            "indices" : Tensor[B x T]
                Codebook indices (quantized discrete representation of the input).
            "dists" : Tensor[(B*T) x codebook_size]
                Distances between the projected latents and the codebook entries.
            "vq_loss" : Tensor[1]
                Weighted sum of the commitment loss (trains the encoder to predict
                vectors closer to codebook entries) and the codebook loss (updates
                the codebook); zero at inference.
            "perplexity" : Tensor[1]
                Exponential of the entropy of the average code usage.
            "active_num" : Tensor[1]
                Number of active codebook entries (EMA-based during training,
                per-batch during inference).
        """
        # Factorized codes: project the input into the low-dimensional codebook
        # space (identity when input_dim == codebook_dim).
        z_e = self.in_project(z)
        z_q, indices, dists = self.decode_latents(z_e)

        # Track codebook usage statistics.
        embed_onehot = F.one_hot(indices, self.codebook_size).type(z_e.dtype)
        avg_probs = torch.mean(embed_onehot.reshape(-1, self.codebook_size), dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        active_num = (embed_onehot.sum(0).sum(0) > 0).sum()

        if self.training:
            # Update the EMA of per-code usage counts and count the codes whose
            # EMA usage exceeds the dead-code threshold (no code re-initialization
            # is performed here).
            ema_inplace(self.cluster_size, embed_onehot.sum(0).sum(0), self.decay)
            active_num = (self.cluster_size > self.threshold_ema_dead_code).sum()

        if self.training:
            commit_loss = (
                F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
                * self.commitment
            )
            codebook_loss = (
                F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
                * self.codebook_loss_weight
            )
        else:
            # No VQ losses at inference; size-1 zeros keep the mean below well
            # defined (the mean of an empty tensor would be NaN).
            commit_loss = torch.zeros(1, device=z.device)
            codebook_loss = torch.zeros(1, device=z.device)

        # Identity in the forward pass, straight-through gradient estimator in the
        # backward pass: gradients w.r.t. z_q flow directly into z_e.
        z_q = z_e + (z_q - z_e).detach()
        z_q = self.out_project(z_q)

        vq_loss = (commit_loss + codebook_loss).mean()

return {
"z_q": z_q,
"indices": indices,
"dists": dists,
"vq_loss": vq_loss,
"perplexity": perplexity,
"active_num": active_num.float(),
}

    def vq2emb(self, vq, out_proj=True):
        """Convert codebook indices to embeddings, optionally applying the output projection."""
        emb = self.embed_code(vq)
        if out_proj:
            emb = self.out_project(emb)
        return emb

    def tokenize(self, z: torch.Tensor) -> torch.Tensor:
        """Tokenize the input tensor into codebook indices."""
        z_e = self.in_project(z)
        _, indices, _ = self.decode_latents(z_e)
        return indices

    def detokenize(self, indices):
        """Detokenize codebook indices back into the continuous representation."""
        z_q = self.decode_code(indices)
        z_q = self.out_project(z_q)
        return z_q

    def get_emb(self):
        """Return the raw codebook weight matrix [codebook_size, codebook_dim]."""
        return self.codebook.weight

    def embed_code(self, embed_id):
        """Look up codebook entries for the given indices: [..., codebook_dim]."""
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        """Look up codebook entries and transpose to channel-first [B, codebook_dim, T]."""
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        """Find the nearest codebook entry for each latent vector."""
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight

        # L2-normalize encodings and codebook entries.
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Squared Euclidean distance between encodings and codebook entries.
        # With unit-norm vectors, ||e - c||^2 = 2 - 2 * cos(e, c), so the nearest
        # code is also the one with the highest cosine similarity.
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)
        return z_q, indices, dist
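

# ---------------------------------------------------------------------------
# Minimal usage sketch. The hyperparameters (input_dim=256, codebook_size=1024,
# codebook_dim=8, commitment=0.25) and tensor shapes below are illustrative
# assumptions and may differ from the actual Spark-TTS configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    vq = FactorizedVectorQuantize(
        input_dim=256,
        codebook_size=1024,
        codebook_dim=8,
        commitment=0.25,
    )
    vq.eval()

    z = torch.randn(2, 256, 50)  # [B, input_dim, T]
    with torch.no_grad():
        out = vq(z)
        print(out["z_q"].shape)      # torch.Size([2, 256, 50])
        print(out["indices"].shape)  # torch.Size([2, 50])
        print(out["perplexity"].item(), out["active_num"].item())

        # tokenize() -> detokenize() reproduces the quantized output of forward().
        tokens = vq.tokenize(z)
        z_rec = vq.detokenize(tokens)
        print(torch.allclose(z_rec, out["z_q"], atol=1e-5))  # True (up to fp error)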