gooorillax's picture
first push of codes and models for g2p, t2u, tokenizer and detokenizer
cd8454d
# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Dehua Tao)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
import torch
from torch import nn
import torch.nn.functional as F
from torchaudio.models import Conformer
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
TimestepEmbedding,
ConvNeXtV2Block,
ConvPositionEmbedding,
AdaLayerNormZero_Final,
precompute_freqs_cis,
get_pos_embed_indices,
)
from model.modules import CADiTBlock
import logging
# Text embedding
class TextEmbedding(nn.Module):
    """Embed text token ids, optionally refined by ConvNeXtV2 blocks.

    Index 0 of the embedding table is reserved as a filler token. ``forward``
    shifts incoming ids by +1, so a batch padding value of -1 maps onto the
    filler.

    Args:
        text_num_embeds: Number of real text tokens; the table holds one
            extra row for the filler.
        text_dim: Width of each text embedding.
        should_extend_text: If true, right-pad the token sequence with the
            filler up to ``seq_len``; otherwise keep the (possibly
            truncated) text length.
        conv_layers: Number of ``ConvNeXtV2Block`` layers. ``0`` disables
            both the positional table and the extra blocks.
        conv_mult: Each block is built as
            ``ConvNeXtV2Block(text_dim, text_dim * conv_mult)``.
    """

    def __init__(
        self,
        text_num_embeds,
        text_dim,
        should_extend_text=True,
        conv_layers=0,
        conv_mult=2,
    ):
        super().__init__()
        # +1 row: index 0 is the filler token.
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)
        self.should_extend_text = should_extend_text
        logging.info(f"should_extend_text={should_extend_text}")

        if conv_layers > 0:
            self.extra_modeling = True
            self.precompute_max_pos = 4096  # ~44s of 24khz audio
            # Non-persistent: recomputed at construction, never stored in
            # the state dict.
            self.register_buffer(
                "freqs_cis",
                precompute_freqs_cis(text_dim, self.precompute_max_pos),
                persistent=False,
            )
            self.text_blocks = nn.Sequential(
                *[
                    ConvNeXtV2Block(text_dim, text_dim * conv_mult)
                    for _ in range(conv_layers)
                ]
            )
        else:
            self.extra_modeling = False

    def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
        """Return text embeddings for a batch of token ids.

        Args:
            text: Integer token ids, with -1 used as batch padding
                (see list_str_to_idx()).
            seq_len: Target sequence length. Text longer than this is
                truncated; shorter text is filler-padded to this length
                only when ``should_extend_text`` is true.
            drop_text: If true, replace every token with the filler
                (classifier-free guidance for text).

        Returns:
            Tensor of shape ``(b, seq_len, text_dim)`` when text is
            extended, otherwise ``(b, min(nt, seq_len), text_dim)``.
        """
        # Shift ids so that 0 is the filler token (padding -1 -> 0).
        text = text + 1
        # Curtail if there are more text tokens than seq_len positions.
        text = text[:, :seq_len]
        batch, text_len = text.shape[0], text.shape[1]

        if self.should_extend_text:
            text = F.pad(text, (0, seq_len - text_len), value=0)
        else:
            # Downstream positional indices must cover the text length only.
            seq_len = text_len

        if drop_text:  # cfg for text
            text = torch.zeros_like(text)

        text = self.text_embed(text)  # b n -> b n d

        # possible extra modeling
        if self.extra_modeling:
            # sinus pos emb
            # Build the start indices on the same device as the embeddings
            # so the buffer lookup below does not mix CPU indices with a
            # buffer living on another device.
            batch_start = torch.zeros(
                (batch,), dtype=torch.long, device=text.device
            )
            pos_idx = get_pos_embed_indices(
                batch_start, seq_len, max_pos=self.precompute_max_pos
            )
            text_pos_embed = self.freqs_cis[pos_idx]
            text = text + text_pos_embed

            # convnextv2 blocks
            text = self.text_blocks(text)

        return text
# noised input audio embedding
class InputAudioEmbedding(nn.Module):
    """Fuse noised audio with its conditioning audio into model features.

    The two inputs are concatenated along the last axis, linearly projected
    to ``out_dim``, and the output of a convolutional position embedding is
    added on top of the projection.

    Args:
        mel_dim: The projection expects ``mel_dim * 2`` input features,
            i.e. the concatenation of ``x`` and ``cond``.
        out_dim: Width of the produced features.
    """

    def __init__(self, mel_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(2 * mel_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)

    def forward(
        self,
        x: float["b n d"],  # noqa: F722
        cond: float["b n d"],  # noqa: F722
        drop_audio_cond=False,
    ):
        """Return the position-augmented projection of ``[x, cond]``.

        When ``drop_audio_cond`` is true the conditioning audio is replaced
        by zeros (classifier-free guidance for the audio condition).
        """
        conditioning = torch.zeros_like(cond) if drop_audio_cond else cond
        fused = torch.cat((x, conditioning), dim=-1)
        projected = self.proj(fused)
        # Residual add of the convolutional position embedding.
        return self.conv_pos_embed(projected) + projected
# Transformer backbone using cross-attention DiT blocks
class CADiT(nn.Module):
    """Transformer backbone built from cross-attention DiT blocks.

    The noised audio and the masked conditioning audio are fused by
    ``InputAudioEmbedding``. The text embedding is not concatenated to the
    audio; it is handed to every ``CADiTBlock`` as a separate input.

    Args:
        dim: Model width.
        depth: Number of ``CADiTBlock`` layers.
        heads: Attention heads per block.
        dim_head: Width per head; also the rotary embedding size.
        dropout: Dropout passed to each block.
        ff_mult: Feed-forward multiplier passed to each block.
        mel_dim: Feature size of the audio input and of the output.
        text_num_embeds: Number of text tokens (excluding the filler).
        text_dim: Text embedding width; defaults to ``mel_dim`` when None.
        should_extend_text: Forwarded to ``TextEmbedding``.
        conv_layers: Forwarded to ``TextEmbedding``.
        long_skip_connection: If true, a bias-free linear layer merges the
            last block's output with the input embedding.
        checkpoint_activations: If true, each block runs under
            ``torch.utils.checkpoint`` to trade compute for memory.
    """

    def __init__(
        self,
        *,
        dim,
        depth=8,
        heads=8,
        dim_head=64,
        dropout=0.1,
        ff_mult=4,
        mel_dim=100,
        text_num_embeds=256,
        text_dim=None,
        should_extend_text=True,
        conv_layers=0,
        long_skip_connection=False,
        checkpoint_activations=False,
    ):
        super().__init__()

        self.time_embed = TimestepEmbedding(dim)
        if text_dim is None:
            text_dim = mel_dim
        self.text_embed = TextEmbedding(
            text_num_embeds,
            text_dim,
            should_extend_text=should_extend_text,
            conv_layers=conv_layers,
        )
        # Modification: only concatenate noisy and masked speech
        self.input_embed = InputAudioEmbedding(mel_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        self.dim = dim
        self.depth = depth

        # Modification: use cross-attention DiT block
        self.transformer_blocks = nn.ModuleList(
            [
                CADiTBlock(
                    dim=dim,
                    text_dim=text_dim,
                    heads=heads,
                    dim_head=dim_head,
                    ff_mult=ff_mult,
                    dropout=dropout,
                )
                for _ in range(depth)
            ]
        )
        self.long_skip_connection = (
            nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
        )

        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

        self.checkpoint_activations = checkpoint_activations

    def ckpt_wrapper(self, module):
        """Wrap ``module`` in a plain function suitable for checkpointing."""

        # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
        def ckpt_forward(*inputs):
            outputs = module(*inputs)
            return outputs

        return ckpt_forward

    def forward(
        self,
        x: float["b n d"],  # noised input audio  # noqa: F722
        cond: float["b n d"],  # masked cond audio  # noqa: F722
        text: int["b nt"],  # text  # noqa: F722
        time: float["b"] | float[""],  # time step  # noqa: F821 F722
        drop_audio_cond,  # cfg for cond audio
        drop_text,  # cfg for text
        mask: bool["b n"] | None = None,  # noqa: F722
    ):
        """Run the backbone and return the projection to ``mel_dim``.

        Args:
            x: Noised input audio.
            cond: Masked conditioning audio.
            text: Text token ids.
            time: Time step, either one value per batch item or a scalar
                tensor that is repeated across the batch.
            drop_audio_cond: Zero out ``cond`` (classifier-free guidance).
            drop_text: Replace text with filler (classifier-free guidance).
            mask: Optional mask forwarded to every block.

        Returns:
            Output of ``proj_out``; the last axis has size ``mel_dim``.
        """
        batch, seq_len = x.shape[0], x.shape[1]
        if time.ndim == 0:
            time = time.repeat(batch)

        # t: conditioning time, x: noised input audio
        t = self.time_embed(time)
        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
        x = self.input_embed(x, cond, drop_audio_cond=drop_audio_cond)

        rope = self.rotary_embed.forward_from_seq_len(seq_len)

        if self.long_skip_connection is not None:
            residual = x

        for block in self.transformer_blocks:
            if self.checkpoint_activations:
                # use_reentrant must be passed explicitly; the non-reentrant
                # implementation is the one recommended by PyTorch.
                x = torch.utils.checkpoint.checkpoint(
                    self.ckpt_wrapper(block),
                    x,
                    text_embed,
                    t,
                    mask,
                    rope,
                    use_reentrant=False,
                )
            else:
                x = block(x, text_embed, t, mask=mask, rope=rope)

        if self.long_skip_connection is not None:
            x = self.long_skip_connection(torch.cat((x, residual), dim=-1))

        x = self.norm_out(x, t)
        output = self.proj_out(x)

        return output