Spaces:
Running
on
Zero
Running
on
Zero
| from __future__ import annotations | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from typing import Literal | |
| from f5_tts.model.modules import MelSpec | |
| from f5_tts.model.utils import ( | |
| default, | |
| exists, | |
| lens_to_mask, | |
| ) | |
| from x_transformers.x_transformers import RotaryEmbedding | |
| from f5_tts.model.modules import ( | |
| ConvPositionEmbedding, | |
| Attention, | |
| AttnProcessor, | |
| FeedForward | |
| ) | |
class SpeedPredictorLayer(nn.Module):
    """One pre-norm transformer encoder layer: self-attention + feed-forward.

    Each sublayer follows the pre-norm residual pattern
    ``x + sublayer(layernorm(x))``.
    """

    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None):
        super().__init__()
        # Self-attention sublayer (rotary position embedding is passed in at call time).
        self.attn = Attention(
            processor=AttnProcessor(pe_attn_head=pe_attn_head),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
            qk_norm=qk_norm,
        )
        # Separate pre-norms for the attention and feed-forward sublayers.
        self.ln1 = nn.LayerNorm(dim, elementwise_affine=True, eps=1e-6)
        self.ln2 = nn.LayerNorm(dim, elementwise_affine=True, eps=1e-6)
        self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

    def forward(self, x, mask=None, rope=None):
        """Run attention then FFN over ``x`` with residual connections.

        x: hidden states; mask: boolean validity mask; rope: rotary embedding.
        """
        # Attention sublayer (pre-norm residual).
        x = x + self.attn(x=self.ln1(x), mask=mask, rope=rope)
        # Feed-forward sublayer (pre-norm residual).
        return x + self.ff(x=self.ln2(x))
class GaussianCrossEntropyLoss(nn.Module):
    """Cross-entropy against a soft Gaussian target centered on the true class.

    Instead of a one-hot target, each class position ``c`` receives weight
    ``exp(-(c - y_true)^2 / (2 * sigma^2))`` (unnormalized), so predictions
    near the correct bin are penalized less — suited to ordinal labels such
    as speed buckets.
    """

    def __init__(self, num_classes, sigma_factor=2.0):
        super().__init__()
        self.num_classes = num_classes  # number of ordinal bins
        self.sigma_factor = sigma_factor  # Gaussian width, in bins

    def forward(self, y_pred, y_true, device):
        """Compute the soft-target loss.

        y_pred: [b, num_classes] logits; y_true: [b] integer class indices.
        Returns a scalar loss (mean over the batch).
        """
        batch = y_true.shape[0]
        # Class-index grid, one row per sample: [b, num_classes]
        grid = torch.arange(self.num_classes, device=device).float().expand(batch, -1)
        # Per-sample sigma (constant here; kept as a tensor for broadcasting).
        sigma = self.sigma_factor * torch.ones_like(y_true, device=device).float()
        # Unnormalized Gaussian bump centered on the ground-truth bin.
        offset = grid - y_true.unsqueeze(-1)  # [b, num_classes]
        soft_target = torch.exp(-(offset.pow(2) / (2 * sigma.pow(2).unsqueeze(-1))))
        log_probs = F.log_softmax(y_pred, dim=-1)
        return -(soft_target * log_probs).sum(dim=-1).mean()
class SpeedTransformer(nn.Module):
    """Transformer encoder over mel frames producing per-utterance speed logits.

    Pipeline: mel projection -> convolutional position embedding -> ``depth``
    pre-norm attention layers (rotary PE) -> attention pooling over time ->
    MLP classifier with ``num_classes`` outputs.
    """

    def __init__(
        self,
        dim,
        depth=6,
        heads=8,
        dropout=0.1,
        ff_mult=4,
        qk_norm=None,
        pe_attn_head=None,
        mel_dim=100,
        num_classes=32,
    ):
        super().__init__()
        self.dim_head = dim // heads
        self.num_classes = num_classes
        self.mel_proj = nn.Linear(mel_dim, dim)
        self.conv_layer = ConvPositionEmbedding(dim=dim)
        self.rotary_embed = RotaryEmbedding(self.dim_head)
        self.transformer_blocks = nn.ModuleList(
            [
                SpeedPredictorLayer(
                    dim=dim,
                    heads=heads,
                    dim_head=self.dim_head,
                    ff_mult=ff_mult,
                    dropout=dropout,
                    qk_norm=qk_norm,
                    pe_attn_head=pe_attn_head,
                )
                for _ in range(depth)
            ]
        )
        # Produces one scalar attention score per frame for sequence pooling.
        self.pool = nn.Sequential(nn.Linear(dim, dim), nn.Tanh(), nn.Linear(dim, 1))
        self.classifier = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, num_classes),
        )

    def forward(self, x, lens):
        """Classify speed from mel features.

        x: [b, seq_len, mel_dim] mel features; lens: [b] valid frame counts.
        Returns [b, num_classes] logits.
        """
        seq_len = x.shape[1]
        mask = lens_to_mask(lens, length=seq_len)  # [b, seq_len], True on valid frames
        hidden = self.conv_layer(self.mel_proj(x), mask)  # [b, seq_len, dim]
        rope = self.rotary_embed.forward_from_seq_len(seq_len)
        for layer in self.transformer_blocks:
            hidden = layer(hidden, mask=mask, rope=rope)
        # Attention pooling: padded positions get a ~-inf score so softmax
        # assigns them zero weight.
        scores = self.pool(hidden)  # [b, seq_len, 1]
        scores = scores.masked_fill(~mask.unsqueeze(-1), -torch.finfo(scores.dtype).max)
        weights = F.softmax(scores, dim=1)  # [b, seq_len, 1]
        pooled = (hidden * weights).sum(dim=1)  # [b, dim]
        return self.classifier(pooled)  # [b, num_classes]
class SpeedMapper:
    """Maps discrete speed-class labels to real-valued speeds on a uniform grid.

    Class ``i`` corresponds to speed ``(i + 1) * delta``, i.e. the grid
    ``delta, 2*delta, ..., num_classes*delta`` (0.25 .. 8.0 for 32 classes
    at the default delta).
    """

    def __init__(
        self,
        num_classes: Literal[32, 72],
        delta: float = 0.25
    ):
        self.num_classes = num_classes
        self.delta = delta
        self.max_speed = float(num_classes) * delta
        # FIX: the grid start was hard-coded to 0.25 instead of `delta`, so any
        # non-default delta produced a wrong grid (and a wrong element count that
        # tripped the assert). Building the grid as integers * delta also avoids
        # float-step accumulation in arange; values are identical for the
        # default delta=0.25.
        self.speed_values = torch.arange(1, num_classes + 1, dtype=torch.float32) * delta
        assert len(self.speed_values) == num_classes, f"Generated {len(self.speed_values)} classes, expected {num_classes}"

    def label_to_speed(self, label: torch.Tensor) -> torch.Tensor:
        """Return the speed value(s) for integer class label tensor `label`."""
        return self.speed_values.to(label.device)[label]  # label * delta + delta
class SpeedPredictor(nn.Module):
    """Predicts a speaking-speed class from audio (raw wave or mel features).

    Wraps a SpeedTransformer classifier, a Gaussian-soft cross-entropy
    training loss, and a label -> speed mapper. ``speed_type`` selects the
    number of speed bins (72 for phonemes, 32 for syllables/words).
    """

    def __init__(
        self,
        speed_type: Literal["phonemes", "syllables", "words"] = "phonemes",
        mel_spec_kwargs: dict = dict(),
        arch_kwargs: dict | None = None,
        sigma_factor: int = 2,
        mel_spec_module: nn.Module | None = None,
        num_channels: int = 100,
    ):
        super().__init__()
        num_classes_map = {
            "phonemes": 72,
            "syllables": 32,
            "words": 32,
        }
        self.num_classes = num_classes_map[speed_type]
        # mel spec
        self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
        num_channels = default(num_channels, self.mel_spec.n_mel_channels)
        self.num_channels = num_channels
        self.speed_transformer = SpeedTransformer(**arch_kwargs, num_classes=self.num_classes)
        self.gce = GaussianCrossEntropyLoss(num_classes=self.num_classes, sigma_factor=sigma_factor)
        self.speed_mapper = SpeedMapper(self.num_classes)

    @property
    def device(self):
        # FIX: was a plain method; `forward` reads `self.device` as an
        # attribute, which returned the bound method object and crashed torch
        # ops expecting a device. @property makes attribute access return the
        # actual device of the module's parameters.
        return next(self.parameters()).device

    def predict_speed(self, audio: torch.Tensor, lens: torch.Tensor | None = None):
        """Return predicted speed values [b] for raw waves [b, nw] or mels [b, n, d]."""
        # raw wave -> mel, then channels last: [b, d, n] -> [b, n, d]
        if audio.ndim == 2:
            audio = self.mel_spec(audio).permute(0, 2, 1)
        batch, seq_len, device = *audio.shape[:2], audio.device
        if not exists(lens):
            lens = torch.full((batch,), seq_len, device=device, dtype=torch.long)
        logits = self.speed_transformer(audio, lens)
        probs = F.softmax(logits, dim=-1)
        pred_class = torch.argmax(probs, dim=-1)
        pred_speed = self.speed_mapper.label_to_speed(pred_class)
        return pred_speed

    def forward(
        self,
        inp: float["b n d"] | float["b nw"],  # mel or raw wave # noqa: F722
        speed: float["b"],  # speed groundtruth (integer class indices for gce)
        lens: int["b"] | None = None,  # noqa: F821
    ):
        """Return the Gaussian cross-entropy training loss for a batch."""
        if inp.ndim == 2:
            # Raw wave: convert to mel and move channels last ([b, d, n] -> [b, n, d]),
            # matching predict_speed; mel input is assumed already [b, n, d].
            inp = self.mel_spec(inp)
            inp = inp.permute(0, 2, 1)
        assert inp.shape[-1] == self.num_channels
        # FIX: default lens to full length (as predict_speed does) —
        # lens_to_mask downstream requires a tensor, so lens=None crashed.
        if not exists(lens):
            lens = torch.full((inp.shape[0],), inp.shape[1], device=inp.device, dtype=torch.long)
        pred = self.speed_transformer(inp, lens)
        loss = self.gce(pred, speed, self.device)
        return loss