Gregniuki commited on
Commit
eda0fb4
·
verified ·
1 Parent(s): 0ef4a15

Delete src/f5-tts

Browse files
src/f5-tts/api.py DELETED
@@ -1,151 +0,0 @@
1
- import random
2
- import sys
3
- from importlib.resources import files
4
-
5
- import soundfile as sf
6
- import torch
7
- import tqdm
8
- from cached_path import cached_path
9
-
10
- from f5_tts.infer.utils_infer import (
11
- hop_length,
12
- infer_process,
13
- load_model,
14
- load_vocoder,
15
- preprocess_ref_audio_text,
16
- remove_silence_for_generated_wav,
17
- save_spectrogram,
18
- target_sample_rate,
19
- )
20
- from f5_tts.model import DiT, UNetT
21
- from f5_tts.model.utils import seed_everything
22
-
23
-
24
- class F5TTS:
25
- def __init__(
26
- self,
27
- model_type="F5-TTS",
28
- ckpt_file="",
29
- vocab_file="",
30
- ode_method="euler",
31
- use_ema=True,
32
- vocoder_name="vocos",
33
- local_path=None,
34
- device=None,
35
- ):
36
- # Initialize parameters
37
- self.final_wave = None
38
- self.target_sample_rate = target_sample_rate
39
- self.hop_length = hop_length
40
- self.seed = -1
41
- self.mel_spec_type = vocoder_name
42
-
43
- # Set device
44
- self.device = device or (
45
- "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
46
- )
47
-
48
- # Load models
49
- self.load_vocoder_model(vocoder_name, local_path)
50
- self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema)
51
-
52
- def load_vocoder_model(self, vocoder_name, local_path):
53
- self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
54
-
55
- def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema):
56
- if model_type == "F5-TTS":
57
- if not ckpt_file:
58
- if mel_spec_type == "vocos":
59
- ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
60
- elif mel_spec_type == "bigvgan":
61
- ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt"))
62
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
63
- model_cls = DiT
64
- elif model_type == "E2-TTS":
65
- if not ckpt_file:
66
- ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
67
- model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
68
- model_cls = UNetT
69
- else:
70
- raise ValueError(f"Unknown model type: {model_type}")
71
-
72
- self.ema_model = load_model(
73
- model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
74
- )
75
-
76
- def export_wav(self, wav, file_wave, remove_silence=False):
77
- sf.write(file_wave, wav, self.target_sample_rate)
78
-
79
- if remove_silence:
80
- remove_silence_for_generated_wav(file_wave)
81
-
82
- def export_spectrogram(self, spect, file_spect):
83
- save_spectrogram(spect, file_spect)
84
-
85
- def infer(
86
- self,
87
- ref_file,
88
- ref_text,
89
- gen_text,
90
- show_info=print,
91
- progress=tqdm,
92
- target_rms=0.1,
93
- cross_fade_duration=0.15,
94
- sway_sampling_coef=-1,
95
- cfg_strength=2,
96
- nfe_step=32,
97
- speed=1.0,
98
- fix_duration=None,
99
- remove_silence=False,
100
- file_wave=None,
101
- file_spect=None,
102
- seed=-1,
103
- ):
104
- if seed == -1:
105
- seed = random.randint(0, sys.maxsize)
106
- seed_everything(seed)
107
- self.seed = seed
108
-
109
- ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
110
-
111
- wav, sr, spect = infer_process(
112
- ref_file,
113
- ref_text,
114
- gen_text,
115
- self.ema_model,
116
- self.vocoder,
117
- self.mel_spec_type,
118
- show_info=show_info,
119
- progress=progress,
120
- target_rms=target_rms,
121
- cross_fade_duration=cross_fade_duration,
122
- nfe_step=nfe_step,
123
- cfg_strength=cfg_strength,
124
- sway_sampling_coef=sway_sampling_coef,
125
- speed=speed,
126
- fix_duration=fix_duration,
127
- device=self.device,
128
- )
129
-
130
- if file_wave is not None:
131
- self.export_wav(wav, file_wave, remove_silence)
132
-
133
- if file_spect is not None:
134
- self.export_spectrogram(spect, file_spect)
135
-
136
- return wav, sr, spect
137
-
138
-
139
- if __name__ == "__main__":
140
- f5tts = F5TTS()
141
-
142
- wav, sr, spect = f5tts.infer(
143
- ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
144
- ref_text="some call me nature, others call me mother nature.",
145
- gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
146
- file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
147
- file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")),
148
- seed=-1, # random seed = -1
149
- )
150
-
151
- print("seed :", f5tts.seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/f5-tts/model/__init__.py DELETED
@@ -1,10 +0,0 @@
1
- from f5_tts.model.cfm import CFM
2
-
3
- from f5_tts.model.backbones.unett import UNetT
4
- from f5_tts.model.backbones.dit import DiT
5
- from f5_tts.model.backbones.mmdit import MMDiT
6
-
7
- from f5_tts.model.trainer import Trainer
8
-
9
-
10
- __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
 
 
 
 
 
 
 
 
 
 
 
src/f5-tts/model/backbones/README.md DELETED
@@ -1,20 +0,0 @@
1
- ## Backbones quick introduction
2
-
3
-
4
- ### unett.py
5
- - flat unet transformer
6
- - structure same as in e2-tts & voicebox paper except using rotary pos emb
7
- - update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat
8
-
9
- ### dit.py
10
- - adaln-zero dit
11
- - embedded timestep as condition
12
- - concatted noised_input + masked_cond + embedded_text, linear proj in
13
- - possible abs pos emb & convnextv2 blocks for embedded text before concat
14
- - possible long skip connection (first layer to last layer)
15
-
16
- ### mmdit.py
17
- - sd3 structure
18
- - timestep as condition
19
- - left stream: text embedded and applied a abs pos emb
20
- - right stream: masked_cond & noised_input concatted and with same conv pos emb as unett
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/f5-tts/model/backbones/dit.py DELETED
@@ -1,163 +0,0 @@
1
- """
2
- ein notation:
3
- b - batch
4
- n - sequence
5
- nt - text sequence
6
- nw - raw wave length
7
- d - dimension
8
- """
9
-
10
- from __future__ import annotations
11
-
12
- import torch
13
- from torch import nn
14
- import torch.nn.functional as F
15
-
16
- from x_transformers.x_transformers import RotaryEmbedding
17
-
18
- from f5_tts.model.modules import (
19
- TimestepEmbedding,
20
- ConvNeXtV2Block,
21
- ConvPositionEmbedding,
22
- DiTBlock,
23
- AdaLayerNormZero_Final,
24
- precompute_freqs_cis,
25
- get_pos_embed_indices,
26
- )
27
-
28
-
29
- # Text embedding
30
-
31
-
32
- class TextEmbedding(nn.Module):
33
- def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
34
- super().__init__()
35
- self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
36
-
37
- if conv_layers > 0:
38
- self.extra_modeling = True
39
- self.precompute_max_pos = 4096 # ~44s of 24khz audio
40
- self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
41
- self.text_blocks = nn.Sequential(
42
- *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
43
- )
44
- else:
45
- self.extra_modeling = False
46
-
47
- def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
48
- text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
49
- text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
50
- batch, text_len = text.shape[0], text.shape[1]
51
- text = F.pad(text, (0, seq_len - text_len), value=0)
52
-
53
- if drop_text: # cfg for text
54
- text = torch.zeros_like(text)
55
-
56
- text = self.text_embed(text) # b n -> b n d
57
-
58
- # possible extra modeling
59
- if self.extra_modeling:
60
- # sinus pos emb
61
- batch_start = torch.zeros((batch,), dtype=torch.long)
62
- pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
63
- text_pos_embed = self.freqs_cis[pos_idx]
64
- text = text + text_pos_embed
65
-
66
- # convnextv2 blocks
67
- text = self.text_blocks(text)
68
-
69
- return text
70
-
71
-
72
- # noised input audio and context mixing embedding
73
-
74
-
75
- class InputEmbedding(nn.Module):
76
- def __init__(self, mel_dim, text_dim, out_dim):
77
- super().__init__()
78
- self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
79
- self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
80
-
81
- def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
82
- if drop_audio_cond: # cfg for cond audio
83
- cond = torch.zeros_like(cond)
84
-
85
- x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
86
- x = self.conv_pos_embed(x) + x
87
- return x
88
-
89
-
90
- # Transformer backbone using DiT blocks
91
-
92
-
93
- class DiT(nn.Module):
94
- def __init__(
95
- self,
96
- *,
97
- dim,
98
- depth=8,
99
- heads=8,
100
- dim_head=64,
101
- dropout=0.1,
102
- ff_mult=4,
103
- mel_dim=100,
104
- text_num_embeds=256,
105
- text_dim=None,
106
- conv_layers=0,
107
- long_skip_connection=False,
108
- ):
109
- super().__init__()
110
-
111
- self.time_embed = TimestepEmbedding(dim)
112
- if text_dim is None:
113
- text_dim = mel_dim
114
- self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
115
- self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
116
-
117
- self.rotary_embed = RotaryEmbedding(dim_head)
118
-
119
- self.dim = dim
120
- self.depth = depth
121
-
122
- self.transformer_blocks = nn.ModuleList(
123
- [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
124
- )
125
- self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
126
-
127
- self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
128
- self.proj_out = nn.Linear(dim, mel_dim)
129
-
130
- def forward(
131
- self,
132
- x: float["b n d"], # nosied input audio # noqa: F722
133
- cond: float["b n d"], # masked cond audio # noqa: F722
134
- text: int["b nt"], # text # noqa: F722
135
- time: float["b"] | float[""], # time step # noqa: F821 F722
136
- drop_audio_cond, # cfg for cond audio
137
- drop_text, # cfg for text
138
- mask: bool["b n"] | None = None, # noqa: F722
139
- ):
140
- batch, seq_len = x.shape[0], x.shape[1]
141
- if time.ndim == 0:
142
- time = time.repeat(batch)
143
-
144
- # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
145
- t = self.time_embed(time)
146
- text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
147
- x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
148
-
149
- rope = self.rotary_embed.forward_from_seq_len(seq_len)
150
-
151
- if self.long_skip_connection is not None:
152
- residual = x
153
-
154
- for block in self.transformer_blocks:
155
- x = block(x, t, mask=mask, rope=rope)
156
-
157
- if self.long_skip_connection is not None:
158
- x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
159
-
160
- x = self.norm_out(x, t)
161
- output = self.proj_out(x)
162
-
163
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/f5-tts/model/backbones/mmdit.py DELETED
@@ -1,146 +0,0 @@
1
- """
2
- ein notation:
3
- b - batch
4
- n - sequence
5
- nt - text sequence
6
- nw - raw wave length
7
- d - dimension
8
- """
9
-
10
- from __future__ import annotations
11
-
12
- import torch
13
- from torch import nn
14
-
15
- from x_transformers.x_transformers import RotaryEmbedding
16
-
17
- from f5_tts.model.modules import (
18
- TimestepEmbedding,
19
- ConvPositionEmbedding,
20
- MMDiTBlock,
21
- AdaLayerNormZero_Final,
22
- precompute_freqs_cis,
23
- get_pos_embed_indices,
24
- )
25
-
26
-
27
- # text embedding
28
-
29
-
30
- class TextEmbedding(nn.Module):
31
- def __init__(self, out_dim, text_num_embeds):
32
- super().__init__()
33
- self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
34
-
35
- self.precompute_max_pos = 1024
36
- self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
37
-
38
- def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
39
- text = text + 1
40
- if drop_text:
41
- text = torch.zeros_like(text)
42
- text = self.text_embed(text)
43
-
44
- # sinus pos emb
45
- batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
46
- batch_text_len = text.shape[1]
47
- pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
48
- text_pos_embed = self.freqs_cis[pos_idx]
49
-
50
- text = text + text_pos_embed
51
-
52
- return text
53
-
54
-
55
- # noised input & masked cond audio embedding
56
-
57
-
58
- class AudioEmbedding(nn.Module):
59
- def __init__(self, in_dim, out_dim):
60
- super().__init__()
61
- self.linear = nn.Linear(2 * in_dim, out_dim)
62
- self.conv_pos_embed = ConvPositionEmbedding(out_dim)
63
-
64
- def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722
65
- if drop_audio_cond:
66
- cond = torch.zeros_like(cond)
67
- x = torch.cat((x, cond), dim=-1)
68
- x = self.linear(x)
69
- x = self.conv_pos_embed(x) + x
70
- return x
71
-
72
-
73
- # Transformer backbone using MM-DiT blocks
74
-
75
-
76
- class MMDiT(nn.Module):
77
- def __init__(
78
- self,
79
- *,
80
- dim,
81
- depth=8,
82
- heads=8,
83
- dim_head=64,
84
- dropout=0.1,
85
- ff_mult=4,
86
- text_num_embeds=256,
87
- mel_dim=100,
88
- ):
89
- super().__init__()
90
-
91
- self.time_embed = TimestepEmbedding(dim)
92
- self.text_embed = TextEmbedding(dim, text_num_embeds)
93
- self.audio_embed = AudioEmbedding(mel_dim, dim)
94
-
95
- self.rotary_embed = RotaryEmbedding(dim_head)
96
-
97
- self.dim = dim
98
- self.depth = depth
99
-
100
- self.transformer_blocks = nn.ModuleList(
101
- [
102
- MMDiTBlock(
103
- dim=dim,
104
- heads=heads,
105
- dim_head=dim_head,
106
- dropout=dropout,
107
- ff_mult=ff_mult,
108
- context_pre_only=i == depth - 1,
109
- )
110
- for i in range(depth)
111
- ]
112
- )
113
- self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
114
- self.proj_out = nn.Linear(dim, mel_dim)
115
-
116
- def forward(
117
- self,
118
- x: float["b n d"], # nosied input audio # noqa: F722
119
- cond: float["b n d"], # masked cond audio # noqa: F722
120
- text: int["b nt"], # text # noqa: F722
121
- time: float["b"] | float[""], # time step # noqa: F821 F722
122
- drop_audio_cond, # cfg for cond audio
123
- drop_text, # cfg for text
124
- mask: bool["b n"] | None = None, # noqa: F722
125
- ):
126
- batch = x.shape[0]
127
- if time.ndim == 0:
128
- time = time.repeat(batch)
129
-
130
- # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
131
- t = self.time_embed(time)
132
- c = self.text_embed(text, drop_text=drop_text)
133
- x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
134
-
135
- seq_len = x.shape[1]
136
- text_len = text.shape[1]
137
- rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
138
- rope_text = self.rotary_embed.forward_from_seq_len(text_len)
139
-
140
- for block in self.transformer_blocks:
141
- c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
142
-
143
- x = self.norm_out(x, t)
144
- output = self.proj_out(x)
145
-
146
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/f5-tts/model/backbones/unett.py DELETED
@@ -1,219 +0,0 @@
1
- """
2
- ein notation:
3
- b - batch
4
- n - sequence
5
- nt - text sequence
6
- nw - raw wave length
7
- d - dimension
8
- """
9
-
10
- from __future__ import annotations
11
- from typing import Literal
12
-
13
- import torch
14
- from torch import nn
15
- import torch.nn.functional as F
16
-
17
- from x_transformers import RMSNorm
18
- from x_transformers.x_transformers import RotaryEmbedding
19
-
20
- from f5_tts.model.modules import (
21
- TimestepEmbedding,
22
- ConvNeXtV2Block,
23
- ConvPositionEmbedding,
24
- Attention,
25
- AttnProcessor,
26
- FeedForward,
27
- precompute_freqs_cis,
28
- get_pos_embed_indices,
29
- )
30
-
31
-
32
- # Text embedding
33
-
34
-
35
- class TextEmbedding(nn.Module):
36
- def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
37
- super().__init__()
38
- self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
39
-
40
- if conv_layers > 0:
41
- self.extra_modeling = True
42
- self.precompute_max_pos = 4096 # ~44s of 24khz audio
43
- self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
44
- self.text_blocks = nn.Sequential(
45
- *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
46
- )
47
- else:
48
- self.extra_modeling = False
49
-
50
- def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
51
- text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
52
- text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
53
- batch, text_len = text.shape[0], text.shape[1]
54
- text = F.pad(text, (0, seq_len - text_len), value=0)
55
-
56
- if drop_text: # cfg for text
57
- text = torch.zeros_like(text)
58
-
59
- text = self.text_embed(text) # b n -> b n d
60
-
61
- # possible extra modeling
62
- if self.extra_modeling:
63
- # sinus pos emb
64
- batch_start = torch.zeros((batch,), dtype=torch.long)
65
- pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
66
- text_pos_embed = self.freqs_cis[pos_idx]
67
- text = text + text_pos_embed
68
-
69
- # convnextv2 blocks
70
- text = self.text_blocks(text)
71
-
72
- return text
73
-
74
-
75
- # noised input audio and context mixing embedding
76
-
77
-
78
- class InputEmbedding(nn.Module):
79
- def __init__(self, mel_dim, text_dim, out_dim):
80
- super().__init__()
81
- self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
82
- self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
83
-
84
- def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
85
- if drop_audio_cond: # cfg for cond audio
86
- cond = torch.zeros_like(cond)
87
-
88
- x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
89
- x = self.conv_pos_embed(x) + x
90
- return x
91
-
92
-
93
- # Flat UNet Transformer backbone
94
-
95
-
96
- class UNetT(nn.Module):
97
- def __init__(
98
- self,
99
- *,
100
- dim,
101
- depth=8,
102
- heads=8,
103
- dim_head=64,
104
- dropout=0.1,
105
- ff_mult=4,
106
- mel_dim=100,
107
- text_num_embeds=256,
108
- text_dim=None,
109
- conv_layers=0,
110
- skip_connect_type: Literal["add", "concat", "none"] = "concat",
111
- ):
112
- super().__init__()
113
- assert depth % 2 == 0, "UNet-Transformer's depth should be even."
114
-
115
- self.time_embed = TimestepEmbedding(dim)
116
- if text_dim is None:
117
- text_dim = mel_dim
118
- self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
119
- self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
120
-
121
- self.rotary_embed = RotaryEmbedding(dim_head)
122
-
123
- # transformer layers & skip connections
124
-
125
- self.dim = dim
126
- self.skip_connect_type = skip_connect_type
127
- needs_skip_proj = skip_connect_type == "concat"
128
-
129
- self.depth = depth
130
- self.layers = nn.ModuleList([])
131
-
132
- for idx in range(depth):
133
- is_later_half = idx >= (depth // 2)
134
-
135
- attn_norm = RMSNorm(dim)
136
- attn = Attention(
137
- processor=AttnProcessor(),
138
- dim=dim,
139
- heads=heads,
140
- dim_head=dim_head,
141
- dropout=dropout,
142
- )
143
-
144
- ff_norm = RMSNorm(dim)
145
- ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
146
-
147
- skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None
148
-
149
- self.layers.append(
150
- nn.ModuleList(
151
- [
152
- skip_proj,
153
- attn_norm,
154
- attn,
155
- ff_norm,
156
- ff,
157
- ]
158
- )
159
- )
160
-
161
- self.norm_out = RMSNorm(dim)
162
- self.proj_out = nn.Linear(dim, mel_dim)
163
-
164
- def forward(
165
- self,
166
- x: float["b n d"], # nosied input audio # noqa: F722
167
- cond: float["b n d"], # masked cond audio # noqa: F722
168
- text: int["b nt"], # text # noqa: F722
169
- time: float["b"] | float[""], # time step # noqa: F821 F722
170
- drop_audio_cond, # cfg for cond audio
171
- drop_text, # cfg for text
172
- mask: bool["b n"] | None = None, # noqa: F722
173
- ):
174
- batch, seq_len = x.shape[0], x.shape[1]
175
- if time.ndim == 0:
176
- time = time.repeat(batch)
177
-
178
- # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
179
- t = self.time_embed(time)
180
- text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
181
- x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
182
-
183
- # postfix time t to input x, [b n d] -> [b n+1 d]
184
- x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
185
- if mask is not None:
186
- mask = F.pad(mask, (1, 0), value=1)
187
-
188
- rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
189
-
190
- # flat unet transformer
191
- skip_connect_type = self.skip_connect_type
192
- skips = []
193
- for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
194
- layer = idx + 1
195
-
196
- # skip connection logic
197
- is_first_half = layer <= (self.depth // 2)
198
- is_later_half = not is_first_half
199
-
200
- if is_first_half:
201
- skips.append(x)
202
-
203
- if is_later_half:
204
- skip = skips.pop()
205
- if skip_connect_type == "concat":
206
- x = torch.cat((x, skip), dim=-1)
207
- x = maybe_skip_proj(x)
208
- elif skip_connect_type == "add":
209
- x = x + skip
210
-
211
- # attention and feedforward blocks
212
- x = attn(attn_norm(x), rope=rope, mask=mask) + x
213
- x = ff(ff_norm(x)) + x
214
-
215
- assert len(skips) == 0
216
-
217
- x = self.norm_out(x)[:, 1:, :] # unpack t from x
218
-
219
- return self.proj_out(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/f5-tts/model/cfm.py DELETED
@@ -1,285 +0,0 @@
1
- """
2
- ein notation:
3
- b - batch
4
- n - sequence
5
- nt - text sequence
6
- nw - raw wave length
7
- d - dimension
8
- """
9
-
10
- from __future__ import annotations
11
-
12
- from random import random
13
- from typing import Callable
14
-
15
- import torch
16
- import torch.nn.functional as F
17
- from torch import nn
18
- from torch.nn.utils.rnn import pad_sequence
19
- from torchdiffeq import odeint
20
-
21
- from f5_tts.model.modules import MelSpec
22
- from f5_tts.model.utils import (
23
- default,
24
- exists,
25
- lens_to_mask,
26
- list_str_to_idx,
27
- list_str_to_tensor,
28
- mask_from_frac_lengths,
29
- )
30
-
31
-
32
- class CFM(nn.Module):
33
- def __init__(
34
- self,
35
- transformer: nn.Module,
36
- sigma=0.0,
37
- odeint_kwargs: dict = dict(
38
- # atol = 1e-5,
39
- # rtol = 1e-5,
40
- method="euler" # 'midpoint'
41
- ),
42
- audio_drop_prob=0.3,
43
- cond_drop_prob=0.2,
44
- num_channels=None,
45
- mel_spec_module: nn.Module | None = None,
46
- mel_spec_kwargs: dict = dict(),
47
- frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
48
- vocab_char_map: dict[str:int] | None = None,
49
- ):
50
- super().__init__()
51
-
52
- self.frac_lengths_mask = frac_lengths_mask
53
-
54
- # mel spec
55
- self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
56
- num_channels = default(num_channels, self.mel_spec.n_mel_channels)
57
- self.num_channels = num_channels
58
-
59
- # classifier-free guidance
60
- self.audio_drop_prob = audio_drop_prob
61
- self.cond_drop_prob = cond_drop_prob
62
-
63
- # transformer
64
- self.transformer = transformer
65
- dim = transformer.dim
66
- self.dim = dim
67
-
68
- # conditional flow related
69
- self.sigma = sigma
70
-
71
- # sampling related
72
- self.odeint_kwargs = odeint_kwargs
73
-
74
- # vocab map for tokenization
75
- self.vocab_char_map = vocab_char_map
76
-
77
- @property
78
- def device(self):
79
- return next(self.parameters()).device
80
-
81
- @torch.no_grad()
82
- def sample(
83
- self,
84
- cond: float["b n d"] | float["b nw"], # noqa: F722
85
- text: int["b nt"] | list[str], # noqa: F722
86
- duration: int | int["b"], # noqa: F821
87
- *,
88
- lens: int["b"] | None = None, # noqa: F821
89
- steps=32,
90
- cfg_strength=1.0,
91
- sway_sampling_coef=None,
92
- seed: int | None = None,
93
- max_duration=4096,
94
- vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
95
- no_ref_audio=False,
96
- duplicate_test=False,
97
- t_inter=0.1,
98
- edit_mask=None,
99
- ):
100
- self.eval()
101
- # raw wave
102
-
103
- if cond.ndim == 2:
104
- cond = self.mel_spec(cond)
105
- cond = cond.permute(0, 2, 1)
106
- assert cond.shape[-1] == self.num_channels
107
-
108
- cond = cond.to(next(self.parameters()).dtype)
109
-
110
- batch, cond_seq_len, device = *cond.shape[:2], cond.device
111
- if not exists(lens):
112
- lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
113
-
114
- # text
115
-
116
- if isinstance(text, list):
117
- if exists(self.vocab_char_map):
118
- text = list_str_to_idx(text, self.vocab_char_map).to(device)
119
- else:
120
- text = list_str_to_tensor(text).to(device)
121
- assert text.shape[0] == batch
122
-
123
- if exists(text):
124
- text_lens = (text != -1).sum(dim=-1)
125
- lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
126
-
127
- # duration
128
-
129
- cond_mask = lens_to_mask(lens)
130
- if edit_mask is not None:
131
- cond_mask = cond_mask & edit_mask
132
-
133
- if isinstance(duration, int):
134
- duration = torch.full((batch,), duration, device=device, dtype=torch.long)
135
-
136
- duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
137
- duration = duration.clamp(max=max_duration)
138
- max_duration = duration.amax()
139
-
140
- # duplicate test corner for inner time step oberservation
141
- if duplicate_test:
142
- test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
143
-
144
- cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
145
- cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
146
- cond_mask = cond_mask.unsqueeze(-1)
147
- step_cond = torch.where(
148
- cond_mask, cond, torch.zeros_like(cond)
149
- ) # allow direct control (cut cond audio) with lens passed in
150
-
151
- if batch > 1:
152
- mask = lens_to_mask(duration)
153
- else: # save memory and speed up, as single inference need no mask currently
154
- mask = None
155
-
156
- # test for no ref audio
157
- if no_ref_audio:
158
- cond = torch.zeros_like(cond)
159
-
160
- # neural ode
161
-
162
- def fn(t, x):
163
- # at each step, conditioning is fixed
164
- # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
165
-
166
- # predict flow
167
- pred = self.transformer(
168
- x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False
169
- )
170
- if cfg_strength < 1e-5:
171
- return pred
172
-
173
- null_pred = self.transformer(
174
- x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True
175
- )
176
- return pred + (pred - null_pred) * cfg_strength
177
-
178
- # noise input
179
- # to make sure batch inference result is same with different batch size, and for sure single inference
180
- # still some difference maybe due to convolutional layers
181
- y0 = []
182
- for dur in duration:
183
- if exists(seed):
184
- torch.manual_seed(seed)
185
- y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
186
- y0 = pad_sequence(y0, padding_value=0, batch_first=True)
187
-
188
- t_start = 0
189
-
190
- # duplicate test corner for inner time step oberservation
191
- if duplicate_test:
192
- t_start = t_inter
193
- y0 = (1 - t_start) * y0 + t_start * test_cond
194
- steps = int(steps * (1 - t_start))
195
-
196
- t = torch.linspace(t_start, 1, steps, device=self.device, dtype=step_cond.dtype)
197
- if sway_sampling_coef is not None:
198
- t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
199
-
200
- trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
201
-
202
- sampled = trajectory[-1]
203
- out = sampled
204
- out = torch.where(cond_mask, cond, out)
205
-
206
- if exists(vocoder):
207
- out = out.permute(0, 2, 1)
208
- out = vocoder(out)
209
-
210
- return out, trajectory
211
-
212
- def forward(
213
- self,
214
- inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722
215
- text: int["b nt"] | list[str], # noqa: F722
216
- *,
217
- lens: int["b"] | None = None, # noqa: F821
218
- noise_scheduler: str | None = None,
219
- ):
220
- # handle raw wave
221
- if inp.ndim == 2:
222
- inp = self.mel_spec(inp)
223
- inp = inp.permute(0, 2, 1)
224
- assert inp.shape[-1] == self.num_channels
225
-
226
- batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
227
-
228
- # handle text as string
229
- if isinstance(text, list):
230
- if exists(self.vocab_char_map):
231
- text = list_str_to_idx(text, self.vocab_char_map).to(device)
232
- else:
233
- text = list_str_to_tensor(text).to(device)
234
- assert text.shape[0] == batch
235
-
236
- # lens and mask
237
- if not exists(lens):
238
- lens = torch.full((batch,), seq_len, device=device)
239
-
240
- mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
241
-
242
- # get a random span to mask out for training conditionally
243
- frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
244
- rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
245
-
246
- if exists(mask):
247
- rand_span_mask &= mask
248
-
249
- # mel is x1
250
- x1 = inp
251
-
252
- # x0 is gaussian noise
253
- x0 = torch.randn_like(x1)
254
-
255
- # time step
256
- time = torch.rand((batch,), dtype=dtype, device=self.device)
257
- # TODO. noise_scheduler
258
-
259
- # sample xt (φ_t(x) in the paper)
260
- t = time.unsqueeze(-1).unsqueeze(-1)
261
- φ = (1 - t) * x0 + t * x1
262
- flow = x1 - x0
263
-
264
- # only predict what is within the random mask span for infilling
265
- cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
266
-
267
- # transformer and cfg training with a drop rate
268
- drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
269
- if random() < self.cond_drop_prob: # p_uncond in voicebox paper
270
- drop_audio_cond = True
271
- drop_text = True
272
- else:
273
- drop_text = False
274
-
275
- # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
276
- # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
277
- pred = self.transformer(
278
- x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
279
- )
280
-
281
- # flow matching loss
282
- loss = F.mse_loss(pred, flow, reduction="none")
283
- loss = loss[rand_span_mask]
284
-
285
- return loss.mean(), cond, pred
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/f5-tts/model/dataset.py DELETED
@@ -1,314 +0,0 @@
1
- import json
2
- import random
3
- from importlib.resources import files
4
-
5
- import torch
6
- import torch.nn.functional as F
7
- import torchaudio
8
- from datasets import Dataset as Dataset_
9
- from datasets import load_from_disk
10
- from torch import nn
11
- from torch.utils.data import Dataset, Sampler
12
- from tqdm import tqdm
13
-
14
- from f5_tts.model.modules import MelSpec
15
- from f5_tts.model.utils import default
16
-
17
-
18
- class HFDataset(Dataset):
19
- def __init__(
20
- self,
21
- hf_dataset: Dataset,
22
- target_sample_rate=24_000,
23
- n_mel_channels=100,
24
- hop_length=256,
25
- n_fft=1024,
26
- win_length=1024,
27
- mel_spec_type="vocos",
28
- ):
29
- self.data = hf_dataset
30
- self.target_sample_rate = target_sample_rate
31
- self.hop_length = hop_length
32
-
33
- self.mel_spectrogram = MelSpec(
34
- n_fft=n_fft,
35
- hop_length=hop_length,
36
- win_length=win_length,
37
- n_mel_channels=n_mel_channels,
38
- target_sample_rate=target_sample_rate,
39
- mel_spec_type=mel_spec_type,
40
- )
41
-
42
- def get_frame_len(self, index):
43
- row = self.data[index]
44
- audio = row["audio"]["array"]
45
- sample_rate = row["audio"]["sampling_rate"]
46
- return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
47
-
48
- def __len__(self):
49
- return len(self.data)
50
-
51
- def __getitem__(self, index):
52
- row = self.data[index]
53
- audio = row["audio"]["array"]
54
-
55
- # logger.info(f"Audio shape: {audio.shape}")
56
-
57
- sample_rate = row["audio"]["sampling_rate"]
58
- duration = audio.shape[-1] / sample_rate
59
-
60
- if duration > 30 or duration < 0.3:
61
- return self.__getitem__((index + 1) % len(self.data))
62
-
63
- audio_tensor = torch.from_numpy(audio).float()
64
-
65
- if sample_rate != self.target_sample_rate:
66
- resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
67
- audio_tensor = resampler(audio_tensor)
68
-
69
- audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t')
70
-
71
- mel_spec = self.mel_spectrogram(audio_tensor)
72
-
73
- mel_spec = mel_spec.squeeze(0) # '1 d t -> d t'
74
-
75
- text = row["text"]
76
-
77
- return dict(
78
- mel_spec=mel_spec,
79
- text=text,
80
- )
81
-
82
-
83
- class CustomDataset(Dataset):
84
- def __init__(
85
- self,
86
- custom_dataset: Dataset,
87
- durations=None,
88
- target_sample_rate=24_000,
89
- hop_length=256,
90
- n_mel_channels=100,
91
- n_fft=1024,
92
- win_length=1024,
93
- mel_spec_type="vocos",
94
- preprocessed_mel=False,
95
- mel_spec_module: nn.Module | None = None,
96
- ):
97
- self.data = custom_dataset
98
- self.durations = durations
99
- self.target_sample_rate = target_sample_rate
100
- self.hop_length = hop_length
101
- self.n_fft = n_fft
102
- self.win_length = win_length
103
- self.mel_spec_type = mel_spec_type
104
- self.preprocessed_mel = preprocessed_mel
105
-
106
- if not preprocessed_mel:
107
- self.mel_spectrogram = default(
108
- mel_spec_module,
109
- MelSpec(
110
- n_fft=n_fft,
111
- hop_length=hop_length,
112
- win_length=win_length,
113
- n_mel_channels=n_mel_channels,
114
- target_sample_rate=target_sample_rate,
115
- mel_spec_type=mel_spec_type,
116
- ),
117
- )
118
-
119
- def get_frame_len(self, index):
120
- if (
121
- self.durations is not None
122
- ): # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
123
- return self.durations[index] * self.target_sample_rate / self.hop_length
124
- return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
125
-
126
- def __len__(self):
127
- return len(self.data)
128
-
129
- def __getitem__(self, index):
130
- row = self.data[index]
131
- audio_path = row["audio_path"]
132
- text = row["text"]
133
- duration = row["duration"]
134
-
135
- if self.preprocessed_mel:
136
- mel_spec = torch.tensor(row["mel_spec"])
137
-
138
- else:
139
- audio, source_sample_rate = torchaudio.load(audio_path)
140
- if audio.shape[0] > 1:
141
- audio = torch.mean(audio, dim=0, keepdim=True)
142
-
143
- if duration > 30 or duration < 0.3:
144
- return self.__getitem__((index + 1) % len(self.data))
145
-
146
- if source_sample_rate != self.target_sample_rate:
147
- resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
148
- audio = resampler(audio)
149
-
150
- mel_spec = self.mel_spectrogram(audio)
151
- mel_spec = mel_spec.squeeze(0) # '1 d t -> d t')
152
-
153
- return dict(
154
- mel_spec=mel_spec,
155
- text=text,
156
- )
157
-
158
-
159
- # Dynamic Batch Sampler
160
-
161
-
162
- class DynamicBatchSampler(Sampler[list[int]]):
163
- """Extension of Sampler that will do the following:
164
- 1. Change the batch size (essentially number of sequences)
165
- in a batch to ensure that the total number of frames are less
166
- than a certain threshold.
167
- 2. Make sure the padding efficiency in the batch is high.
168
- """
169
-
170
- def __init__(
171
- self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False
172
- ):
173
- self.sampler = sampler
174
- self.frames_threshold = frames_threshold
175
- self.max_samples = max_samples
176
-
177
- indices, batches = [], []
178
- data_source = self.sampler.data_source
179
-
180
- for idx in tqdm(
181
- self.sampler, desc="Sorting with sampler... if slow, check whether dataset is provided with duration"
182
- ):
183
- indices.append((idx, data_source.get_frame_len(idx)))
184
- indices.sort(key=lambda elem: elem[1])
185
-
186
- batch = []
187
- batch_frames = 0
188
- for idx, frame_len in tqdm(
189
- indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"
190
- ):
191
- if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
192
- batch.append(idx)
193
- batch_frames += frame_len
194
- else:
195
- if len(batch) > 0:
196
- batches.append(batch)
197
- if frame_len <= self.frames_threshold:
198
- batch = [idx]
199
- batch_frames = frame_len
200
- else:
201
- batch = []
202
- batch_frames = 0
203
-
204
- if not drop_last and len(batch) > 0:
205
- batches.append(batch)
206
-
207
- del indices
208
-
209
- # if want to have different batches between epochs, may just set a seed and log it in ckpt
210
- # cuz during multi-gpu training, although the batch on per gpu not change between epochs, the formed general minibatch is different
211
- # e.g. for epoch n, use (random_seed + n)
212
- random.seed(random_seed)
213
- random.shuffle(batches)
214
-
215
- self.batches = batches
216
-
217
- def __iter__(self):
218
- return iter(self.batches)
219
-
220
- def __len__(self):
221
- return len(self.batches)
222
-
223
-
224
- # Load dataset
225
-
226
-
227
- def load_dataset(
228
- dataset_name: str,
229
- tokenizer: str = "pinyin",
230
- dataset_type: str = "CustomDataset",
231
- audio_type: str = "raw",
232
- mel_spec_module: nn.Module | None = None,
233
- mel_spec_kwargs: dict = dict(),
234
- ) -> CustomDataset | HFDataset:
235
- """
236
- dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
237
- - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
238
- """
239
-
240
- print("Loading dataset ...")
241
-
242
- if dataset_type == "CustomDataset":
243
- rel_data_path = str(files("f5_tts").joinpath(f"../../data/{dataset_name}_{tokenizer}"))
244
- if audio_type == "raw":
245
- try:
246
- train_dataset = load_from_disk(f"{rel_data_path}/raw")
247
- except: # noqa: E722
248
- train_dataset = Dataset_.from_file(f"{rel_data_path}/raw.arrow")
249
- preprocessed_mel = False
250
- elif audio_type == "mel":
251
- train_dataset = Dataset_.from_file(f"{rel_data_path}/mel.arrow")
252
- preprocessed_mel = True
253
- with open(f"{rel_data_path}/duration.json", "r", encoding="utf-8") as f:
254
- data_dict = json.load(f)
255
- durations = data_dict["duration"]
256
- train_dataset = CustomDataset(
257
- train_dataset,
258
- durations=durations,
259
- preprocessed_mel=preprocessed_mel,
260
- mel_spec_module=mel_spec_module,
261
- **mel_spec_kwargs,
262
- )
263
-
264
- elif dataset_type == "CustomDatasetPath":
265
- try:
266
- train_dataset = load_from_disk(f"{dataset_name}/raw")
267
- except: # noqa: E722
268
- train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")
269
-
270
- with open(f"{dataset_name}/duration.json", "r", encoding="utf-8") as f:
271
- data_dict = json.load(f)
272
- durations = data_dict["duration"]
273
- train_dataset = CustomDataset(
274
- train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs
275
- )
276
-
277
- elif dataset_type == "HFDataset":
278
- print(
279
- "Should manually modify the path of huggingface dataset to your need.\n"
280
- + "May also the corresponding script cuz different dataset may have different format."
281
- )
282
- pre, post = dataset_name.split("_")
283
- train_dataset = HFDataset(
284
- load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir=str(files("f5_tts").joinpath("../../data"))),
285
- )
286
-
287
- return train_dataset
288
-
289
-
290
- # collation
291
-
292
-
293
- def collate_fn(batch):
294
- mel_specs = [item["mel_spec"].squeeze(0) for item in batch]
295
- mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
296
- max_mel_length = mel_lengths.amax()
297
-
298
- padded_mel_specs = []
299
- for spec in mel_specs: # TODO. maybe records mask for attention here
300
- padding = (0, max_mel_length - spec.size(-1))
301
- padded_spec = F.pad(spec, padding, value=0)
302
- padded_mel_specs.append(padded_spec)
303
-
304
- mel_specs = torch.stack(padded_mel_specs)
305
-
306
- text = [item["text"] for item in batch]
307
- text_lengths = torch.LongTensor([len(item) for item in text])
308
-
309
- return dict(
310
- mel=mel_specs,
311
- mel_lengths=mel_lengths,
312
- text=text,
313
- text_lengths=text_lengths,
314
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/f5-tts/model/modules.py DELETED
@@ -1,658 +0,0 @@
1
- """
2
- ein notation:
3
- b - batch
4
- n - sequence
5
- nt - text sequence
6
- nw - raw wave length
7
- d - dimension
8
- """
9
-
10
- from __future__ import annotations
11
-
12
- import math
13
- from typing import Optional
14
-
15
- import torch
16
- import torch.nn.functional as F
17
- import torchaudio
18
- from librosa.filters import mel as librosa_mel_fn
19
- from torch import nn
20
- from x_transformers.x_transformers import apply_rotary_pos_emb
21
-
22
-
23
- # raw wav to mel spec
24
-
25
-
26
- mel_basis_cache = {}
27
- hann_window_cache = {}
28
-
29
-
30
- def get_bigvgan_mel_spectrogram(
31
- waveform,
32
- n_fft=1024,
33
- n_mel_channels=100,
34
- target_sample_rate=24000,
35
- hop_length=256,
36
- win_length=1024,
37
- fmin=0,
38
- fmax=None,
39
- center=False,
40
- ): # Copy from https://github.com/NVIDIA/BigVGAN/tree/main
41
- device = waveform.device
42
- key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
43
-
44
- if key not in mel_basis_cache:
45
- mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax)
46
- mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) # TODO: why they need .float()?
47
- hann_window_cache[key] = torch.hann_window(win_length).to(device)
48
-
49
- mel_basis = mel_basis_cache[key]
50
- hann_window = hann_window_cache[key]
51
-
52
- padding = (n_fft - hop_length) // 2
53
- waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
54
-
55
- spec = torch.stft(
56
- waveform,
57
- n_fft,
58
- hop_length=hop_length,
59
- win_length=win_length,
60
- window=hann_window,
61
- center=center,
62
- pad_mode="reflect",
63
- normalized=False,
64
- onesided=True,
65
- return_complex=True,
66
- )
67
- spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
68
-
69
- mel_spec = torch.matmul(mel_basis, spec)
70
- mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
71
-
72
- return mel_spec
73
-
74
-
75
- def get_vocos_mel_spectrogram(
76
- waveform,
77
- n_fft=1024,
78
- n_mel_channels=100,
79
- target_sample_rate=24000,
80
- hop_length=256,
81
- win_length=1024,
82
- ):
83
- mel_stft = torchaudio.transforms.MelSpectrogram(
84
- sample_rate=target_sample_rate,
85
- n_fft=n_fft,
86
- win_length=win_length,
87
- hop_length=hop_length,
88
- n_mels=n_mel_channels,
89
- power=1,
90
- center=True,
91
- normalized=False,
92
- norm=None,
93
- ).to(waveform.device)
94
- if len(waveform.shape) == 3:
95
- waveform = waveform.squeeze(1) # 'b 1 nw -> b nw'
96
-
97
- assert len(waveform.shape) == 2
98
-
99
- mel = mel_stft(waveform)
100
- mel = mel.clamp(min=1e-5).log()
101
- return mel
102
-
103
-
104
- class MelSpec(nn.Module):
105
- def __init__(
106
- self,
107
- n_fft=1024,
108
- hop_length=256,
109
- win_length=1024,
110
- n_mel_channels=100,
111
- target_sample_rate=24_000,
112
- mel_spec_type="vocos",
113
- ):
114
- super().__init__()
115
- assert mel_spec_type in ["vocos", "bigvgan"], print("We only support two extract mel backend: vocos or bigvgan")
116
-
117
- self.n_fft = n_fft
118
- self.hop_length = hop_length
119
- self.win_length = win_length
120
- self.n_mel_channels = n_mel_channels
121
- self.target_sample_rate = target_sample_rate
122
-
123
- if mel_spec_type == "vocos":
124
- self.extractor = get_vocos_mel_spectrogram
125
- elif mel_spec_type == "bigvgan":
126
- self.extractor = get_bigvgan_mel_spectrogram
127
-
128
- self.register_buffer("dummy", torch.tensor(0), persistent=False)
129
-
130
- def forward(self, wav):
131
- if self.dummy.device != wav.device:
132
- self.to(wav.device)
133
-
134
- mel = self.extractor(
135
- waveform=wav,
136
- n_fft=self.n_fft,
137
- n_mel_channels=self.n_mel_channels,
138
- target_sample_rate=self.target_sample_rate,
139
- hop_length=self.hop_length,
140
- win_length=self.win_length,
141
- )
142
-
143
- return mel
144
-
145
-
146
- # sinusoidal position embedding
147
-
148
-
149
- class SinusPositionEmbedding(nn.Module):
150
- def __init__(self, dim):
151
- super().__init__()
152
- self.dim = dim
153
-
154
- def forward(self, x, scale=1000):
155
- device = x.device
156
- half_dim = self.dim // 2
157
- emb = math.log(10000) / (half_dim - 1)
158
- emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
159
- emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
160
- emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
161
- return emb
162
-
163
-
164
- # convolutional position embedding
165
-
166
-
167
- class ConvPositionEmbedding(nn.Module):
168
- def __init__(self, dim, kernel_size=31, groups=16):
169
- super().__init__()
170
- assert kernel_size % 2 != 0
171
- self.conv1d = nn.Sequential(
172
- nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
173
- nn.Mish(),
174
- nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
175
- nn.Mish(),
176
- )
177
-
178
- def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
179
- if mask is not None:
180
- mask = mask[..., None]
181
- x = x.masked_fill(~mask, 0.0)
182
-
183
- x = x.permute(0, 2, 1)
184
- x = self.conv1d(x)
185
- out = x.permute(0, 2, 1)
186
-
187
- if mask is not None:
188
- out = out.masked_fill(~mask, 0.0)
189
-
190
- return out
191
-
192
-
193
- # rotary positional embedding related
194
-
195
-
196
- def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
197
- # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
198
- # has some connection to NTK literature
199
- # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
200
- # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
201
- theta *= theta_rescale_factor ** (dim / (dim - 2))
202
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
203
- t = torch.arange(end, device=freqs.device) # type: ignore
204
- freqs = torch.outer(t, freqs).float() # type: ignore
205
- freqs_cos = torch.cos(freqs) # real part
206
- freqs_sin = torch.sin(freqs) # imaginary part
207
- return torch.cat([freqs_cos, freqs_sin], dim=-1)
208
-
209
-
210
- def get_pos_embed_indices(start, length, max_pos, scale=1.0):
211
- # length = length if isinstance(length, int) else length.max()
212
- scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
213
- pos = (
214
- start.unsqueeze(1)
215
- + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
216
- )
217
- # avoid extra long error.
218
- pos = torch.where(pos < max_pos, pos, max_pos - 1)
219
- return pos
220
-
221
-
222
- # Global Response Normalization layer (Instance Normalization ?)
223
-
224
-
225
- class GRN(nn.Module):
226
- def __init__(self, dim):
227
- super().__init__()
228
- self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
229
- self.beta = nn.Parameter(torch.zeros(1, 1, dim))
230
-
231
- def forward(self, x):
232
- Gx = torch.norm(x, p=2, dim=1, keepdim=True)
233
- Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
234
- return self.gamma * (x * Nx) + self.beta + x
235
-
236
-
237
- # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
238
- # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
239
-
240
-
241
- class ConvNeXtV2Block(nn.Module):
242
- def __init__(
243
- self,
244
- dim: int,
245
- intermediate_dim: int,
246
- dilation: int = 1,
247
- ):
248
- super().__init__()
249
- padding = (dilation * (7 - 1)) // 2
250
- self.dwconv = nn.Conv1d(
251
- dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
252
- ) # depthwise conv
253
- self.norm = nn.LayerNorm(dim, eps=1e-6)
254
- self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
255
- self.act = nn.GELU()
256
- self.grn = GRN(intermediate_dim)
257
- self.pwconv2 = nn.Linear(intermediate_dim, dim)
258
-
259
- def forward(self, x: torch.Tensor) -> torch.Tensor:
260
- residual = x
261
- x = x.transpose(1, 2) # b n d -> b d n
262
- x = self.dwconv(x)
263
- x = x.transpose(1, 2) # b d n -> b n d
264
- x = self.norm(x)
265
- x = self.pwconv1(x)
266
- x = self.act(x)
267
- x = self.grn(x)
268
- x = self.pwconv2(x)
269
- return residual + x
270
-
271
-
272
- # AdaLayerNormZero
273
- # return with modulated x for attn input, and params for later mlp modulation
274
-
275
-
276
- class AdaLayerNormZero(nn.Module):
277
- def __init__(self, dim):
278
- super().__init__()
279
-
280
- self.silu = nn.SiLU()
281
- self.linear = nn.Linear(dim, dim * 6)
282
-
283
- self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
284
-
285
- def forward(self, x, emb=None):
286
- emb = self.linear(self.silu(emb))
287
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
288
-
289
- x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
290
- return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
291
-
292
-
293
- # AdaLayerNormZero for final layer
294
- # return only with modulated x for attn input, cuz no more mlp modulation
295
-
296
-
297
- class AdaLayerNormZero_Final(nn.Module):
298
- def __init__(self, dim):
299
- super().__init__()
300
-
301
- self.silu = nn.SiLU()
302
- self.linear = nn.Linear(dim, dim * 2)
303
-
304
- self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
305
-
306
- def forward(self, x, emb):
307
- emb = self.linear(self.silu(emb))
308
- scale, shift = torch.chunk(emb, 2, dim=1)
309
-
310
- x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
311
- return x
312
-
313
-
314
- # FeedForward
315
-
316
-
317
- class FeedForward(nn.Module):
318
- def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
319
- super().__init__()
320
- inner_dim = int(dim * mult)
321
- dim_out = dim_out if dim_out is not None else dim
322
-
323
- activation = nn.GELU(approximate=approximate)
324
- project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
325
- self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
326
-
327
- def forward(self, x):
328
- return self.ff(x)
329
-
330
-
331
- # Attention with possible joint part
332
- # modified from diffusers/src/diffusers/models/attention_processor.py
333
-
334
-
335
- class Attention(nn.Module):
336
- def __init__(
337
- self,
338
- processor: JointAttnProcessor | AttnProcessor,
339
- dim: int,
340
- heads: int = 8,
341
- dim_head: int = 64,
342
- dropout: float = 0.0,
343
- context_dim: Optional[int] = None, # if not None -> joint attention
344
- context_pre_only=None,
345
- ):
346
- super().__init__()
347
-
348
- if not hasattr(F, "scaled_dot_product_attention"):
349
- raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
350
-
351
- self.processor = processor
352
-
353
- self.dim = dim
354
- self.heads = heads
355
- self.inner_dim = dim_head * heads
356
- self.dropout = dropout
357
-
358
- self.context_dim = context_dim
359
- self.context_pre_only = context_pre_only
360
-
361
- self.to_q = nn.Linear(dim, self.inner_dim)
362
- self.to_k = nn.Linear(dim, self.inner_dim)
363
- self.to_v = nn.Linear(dim, self.inner_dim)
364
-
365
- if self.context_dim is not None:
366
- self.to_k_c = nn.Linear(context_dim, self.inner_dim)
367
- self.to_v_c = nn.Linear(context_dim, self.inner_dim)
368
- if self.context_pre_only is not None:
369
- self.to_q_c = nn.Linear(context_dim, self.inner_dim)
370
-
371
- self.to_out = nn.ModuleList([])
372
- self.to_out.append(nn.Linear(self.inner_dim, dim))
373
- self.to_out.append(nn.Dropout(dropout))
374
-
375
- if self.context_pre_only is not None and not self.context_pre_only:
376
- self.to_out_c = nn.Linear(self.inner_dim, dim)
377
-
378
- def forward(
379
- self,
380
- x: float["b n d"], # noised input x # noqa: F722
381
- c: float["b n d"] = None, # context c # noqa: F722
382
- mask: bool["b n"] | None = None, # noqa: F722
383
- rope=None, # rotary position embedding for x
384
- c_rope=None, # rotary position embedding for c
385
- ) -> torch.Tensor:
386
- if c is not None:
387
- return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
388
- else:
389
- return self.processor(self, x, mask=mask, rope=rope)
390
-
391
-
392
- # Attention processor
393
-
394
-
395
- class AttnProcessor:
396
- def __init__(self):
397
- pass
398
-
399
- def __call__(
400
- self,
401
- attn: Attention,
402
- x: float["b n d"], # noised input x # noqa: F722
403
- mask: bool["b n"] | None = None, # noqa: F722
404
- rope=None, # rotary position embedding
405
- ) -> torch.FloatTensor:
406
- batch_size = x.shape[0]
407
-
408
- # `sample` projections.
409
- query = attn.to_q(x)
410
- key = attn.to_k(x)
411
- value = attn.to_v(x)
412
-
413
- # apply rotary position embedding
414
- if rope is not None:
415
- freqs, xpos_scale = rope
416
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
417
-
418
- query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
419
- key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
420
-
421
- # attention
422
- inner_dim = key.shape[-1]
423
- head_dim = inner_dim // attn.heads
424
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
425
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
426
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
427
-
428
- # mask, e.g. when an inference batch has different target durations, mask out the padding
429
- if mask is not None:
430
- attn_mask = mask
431
- attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
432
- attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
433
- else:
434
- attn_mask = None
435
-
436
- x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
437
- x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
438
- x = x.to(query.dtype)
439
-
440
- # linear proj
441
- x = attn.to_out[0](x)
442
- # dropout
443
- x = attn.to_out[1](x)
444
-
445
- if mask is not None:
446
- mask = mask.unsqueeze(-1)
447
- x = x.masked_fill(~mask, 0.0)
448
-
449
- return x
450
-
451
-
452
- # Joint Attention processor for MM-DiT
453
- # modified from diffusers/src/diffusers/models/attention_processor.py
454
-
455
-
456
- class JointAttnProcessor:
457
- def __init__(self):
458
- pass
459
-
460
- def __call__(
461
- self,
462
- attn: Attention,
463
- x: float["b n d"], # noised input x # noqa: F722
464
- c: float["b nt d"] = None, # context c, here text # noqa: F722
465
- mask: bool["b n"] | None = None, # noqa: F722
466
- rope=None, # rotary position embedding for x
467
- c_rope=None, # rotary position embedding for c
468
- ) -> torch.FloatTensor:
469
- residual = x
470
-
471
- batch_size = c.shape[0]
472
-
473
- # `sample` projections.
474
- query = attn.to_q(x)
475
- key = attn.to_k(x)
476
- value = attn.to_v(x)
477
-
478
- # `context` projections.
479
- c_query = attn.to_q_c(c)
480
- c_key = attn.to_k_c(c)
481
- c_value = attn.to_v_c(c)
482
-
483
- # apply rope for context and noised input independently
484
- if rope is not None:
485
- freqs, xpos_scale = rope
486
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
487
- query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
488
- key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
489
- if c_rope is not None:
490
- freqs, xpos_scale = c_rope
491
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
492
- c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
493
- c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
494
-
495
- # attention
496
- query = torch.cat([query, c_query], dim=1)
497
- key = torch.cat([key, c_key], dim=1)
498
- value = torch.cat([value, c_value], dim=1)
499
-
500
- inner_dim = key.shape[-1]
501
- head_dim = inner_dim // attn.heads
502
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
503
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
504
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
505
-
506
- # mask, e.g. when an inference batch has different target durations, mask out the padding
507
- if mask is not None:
508
- attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
509
- attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
510
- attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
511
- else:
512
- attn_mask = None
513
-
514
- x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
515
- x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
516
- x = x.to(query.dtype)
517
-
518
- # Split the attention outputs.
519
- x, c = (
520
- x[:, : residual.shape[1]],
521
- x[:, residual.shape[1] :],
522
- )
523
-
524
- # linear proj
525
- x = attn.to_out[0](x)
526
- # dropout
527
- x = attn.to_out[1](x)
528
- if not attn.context_pre_only:
529
- c = attn.to_out_c(c)
530
-
531
- if mask is not None:
532
- mask = mask.unsqueeze(-1)
533
- x = x.masked_fill(~mask, 0.0)
534
- # c = c.masked_fill(~mask, 0.) # no mask for c (text)
535
-
536
- return x, c
537
-
538
-
539
- # DiT Block
540
-
541
-
542
- class DiTBlock(nn.Module):
543
- def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
544
- super().__init__()
545
-
546
- self.attn_norm = AdaLayerNormZero(dim)
547
- self.attn = Attention(
548
- processor=AttnProcessor(),
549
- dim=dim,
550
- heads=heads,
551
- dim_head=dim_head,
552
- dropout=dropout,
553
- )
554
-
555
- self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
556
- self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
557
-
558
- def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
559
- # pre-norm & modulation for attention input
560
- norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
561
-
562
- # attention
563
- attn_output = self.attn(x=norm, mask=mask, rope=rope)
564
-
565
- # process attention output for input x
566
- x = x + gate_msa.unsqueeze(1) * attn_output
567
-
568
- norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
569
- ff_output = self.ff(norm)
570
- x = x + gate_mlp.unsqueeze(1) * ff_output
571
-
572
- return x
573
-
574
-
575
- # MMDiT Block https://arxiv.org/abs/2403.03206
576
-
577
-
578
- class MMDiTBlock(nn.Module):
579
- r"""
580
- modified from diffusers/src/diffusers/models/attention.py
581
-
582
- notes.
583
- _c: context related. text, cond, etc. (left part in sd3 fig2.b)
584
- _x: noised input related. (right part)
585
- context_pre_only: the last layer only does pre-norm + modulation, since it has no feed-forward
586
- """
587
-
588
- def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
589
- super().__init__()
590
-
591
- self.context_pre_only = context_pre_only
592
-
593
- self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
594
- self.attn_norm_x = AdaLayerNormZero(dim)
595
- self.attn = Attention(
596
- processor=JointAttnProcessor(),
597
- dim=dim,
598
- heads=heads,
599
- dim_head=dim_head,
600
- dropout=dropout,
601
- context_dim=dim,
602
- context_pre_only=context_pre_only,
603
- )
604
-
605
- if not context_pre_only:
606
- self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
607
- self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
608
- else:
609
- self.ff_norm_c = None
610
- self.ff_c = None
611
- self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
612
- self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
613
-
614
- def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
615
- # pre-norm & modulation for attention input
616
- if self.context_pre_only:
617
- norm_c = self.attn_norm_c(c, t)
618
- else:
619
- norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
620
- norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
621
-
622
- # attention
623
- x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
624
-
625
- # process attention output for context c
626
- if self.context_pre_only:
627
- c = None
628
- else: # if not last layer
629
- c = c + c_gate_msa.unsqueeze(1) * c_attn_output
630
-
631
- norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
632
- c_ff_output = self.ff_c(norm_c)
633
- c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
634
-
635
- # process attention output for input x
636
- x = x + x_gate_msa.unsqueeze(1) * x_attn_output
637
-
638
- norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
639
- x_ff_output = self.ff_x(norm_x)
640
- x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
641
-
642
- return c, x
643
-
644
-
645
- # time step conditioning embedding
646
-
647
-
648
- class TimestepEmbedding(nn.Module):
649
- def __init__(self, dim, freq_embed_dim=256):
650
- super().__init__()
651
- self.time_embed = SinusPositionEmbedding(freq_embed_dim)
652
- self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
653
-
654
- def forward(self, timestep: float["b"]): # noqa: F821
655
- time_hidden = self.time_embed(timestep)
656
- time_hidden = time_hidden.to(timestep.dtype)
657
- time = self.time_mlp(time_hidden) # b d
658
- return time
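
For reference, a minimal sketch of how the blocks above fit together, assuming they stay importable under the upstream layout (f5_tts/model/modules.py); all dimensions are illustrative, not prescribed.

# Illustrative only: run one DiTBlock forward pass on random tensors.
import torch

from f5_tts.model.modules import DiTBlock, TimestepEmbedding  # assumed import path

dim, heads, dim_head = 512, 8, 64
block = DiTBlock(dim=dim, heads=heads, dim_head=dim_head)
time_embed = TimestepEmbedding(dim=dim)

x = torch.randn(2, 100, dim)                 # noised input: batch of 2, 100 mel frames
t = time_embed(torch.rand(2))                # per-sample timestep -> (2, dim) conditioning
mask = torch.ones(2, 100, dtype=torch.bool)  # no padding in this toy batch

out = block(x, t, mask=mask)                 # AdaLN-Zero modulation + self-attention + feed-forward
print(out.shape)                             # torch.Size([2, 100, 512])
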
src/f5-tts/model/trainer.py DELETED
@@ -1,353 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import gc
4
- import os
5
-
6
- import torch
7
- import torchaudio
8
- import wandb
9
- from accelerate import Accelerator
10
- from accelerate.utils import DistributedDataParallelKwargs
11
- from ema_pytorch import EMA
12
- from torch.optim import AdamW
13
- from torch.optim.lr_scheduler import LinearLR, SequentialLR
14
- from torch.utils.data import DataLoader, Dataset, SequentialSampler
15
- from tqdm import tqdm
16
-
17
- from f5_tts.model import CFM
18
- from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
19
- from f5_tts.model.utils import default, exists
20
-
21
- # trainer
22
-
23
-
24
- class Trainer:
25
- def __init__(
26
- self,
27
- model: CFM,
28
- epochs,
29
- learning_rate,
30
- num_warmup_updates=20000,
31
- save_per_updates=1000,
32
- checkpoint_path=None,
33
- batch_size=32,
34
- batch_size_type: str = "sample",
35
- max_samples=32,
36
- grad_accumulation_steps=1,
37
- max_grad_norm=1.0,
38
- noise_scheduler: str | None = None,
39
- duration_predictor: torch.nn.Module | None = None,
40
- logger: str | None = "wandb", # "wandb" | "tensorboard" | None
41
- wandb_project="test_e2-tts",
42
- wandb_run_name="test_run",
43
- wandb_resume_id: str = None,
44
- log_samples: bool = False,
45
- last_per_steps=None,
46
- accelerate_kwargs: dict = dict(),
47
- ema_kwargs: dict = dict(),
48
- bnb_optimizer: bool = False,
49
- mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
50
- ):
51
- ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
52
-
53
- if logger == "wandb" and not wandb.api.api_key:
54
- logger = None
55
- print(f"Using logger: {logger}")
56
- self.log_samples = log_samples
57
-
58
- self.accelerator = Accelerator(
59
- log_with=logger if logger == "wandb" else None,
60
- kwargs_handlers=[ddp_kwargs],
61
- gradient_accumulation_steps=grad_accumulation_steps,
62
- **accelerate_kwargs,
63
- )
64
-
65
- self.logger = logger
66
- if self.logger == "wandb":
67
- if exists(wandb_resume_id):
68
- init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
69
- else:
70
- init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
71
-
72
- self.accelerator.init_trackers(
73
- project_name=wandb_project,
74
- init_kwargs=init_kwargs,
75
- config={
76
- "epochs": epochs,
77
- "learning_rate": learning_rate,
78
- "num_warmup_updates": num_warmup_updates,
79
- "batch_size": batch_size,
80
- "batch_size_type": batch_size_type,
81
- "max_samples": max_samples,
82
- "grad_accumulation_steps": grad_accumulation_steps,
83
- "max_grad_norm": max_grad_norm,
84
- "gpus": self.accelerator.num_processes,
85
- "noise_scheduler": noise_scheduler,
86
- },
87
- )
88
-
89
- elif self.logger == "tensorboard":
90
- from torch.utils.tensorboard import SummaryWriter
91
-
92
- self.writer = SummaryWriter(log_dir=f"runs/{wandb_run_name}")
93
-
94
- self.model = model
95
-
96
- if self.is_main:
97
- self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
98
- self.ema_model.to(self.accelerator.device)
99
-
100
- self.epochs = epochs
101
- self.num_warmup_updates = num_warmup_updates
102
- self.save_per_updates = save_per_updates
103
- self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
104
- self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
105
-
106
- self.batch_size = batch_size
107
- self.batch_size_type = batch_size_type
108
- self.max_samples = max_samples
109
- self.grad_accumulation_steps = grad_accumulation_steps
110
- self.max_grad_norm = max_grad_norm
111
- self.vocoder_name = mel_spec_type
112
-
113
- self.noise_scheduler = noise_scheduler
114
-
115
- self.duration_predictor = duration_predictor
116
-
117
- if bnb_optimizer:
118
- import bitsandbytes as bnb
119
-
120
- self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
121
- else:
122
- self.optimizer = AdamW(model.parameters(), lr=learning_rate)
123
- self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
124
-
125
- @property
126
- def is_main(self):
127
- return self.accelerator.is_main_process
128
-
129
- def save_checkpoint(self, step, last=False):
130
- self.accelerator.wait_for_everyone()
131
- if self.is_main:
132
- checkpoint = dict(
133
- model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
134
- optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
135
- ema_model_state_dict=self.ema_model.state_dict(),
136
- scheduler_state_dict=self.scheduler.state_dict(),
137
- step=step,
138
- )
139
- if not os.path.exists(self.checkpoint_path):
140
- os.makedirs(self.checkpoint_path)
141
- if last:
142
- self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
143
- print(f"Saved last checkpoint at step {step}")
144
- else:
145
- self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
146
-
147
- def load_checkpoint(self):
148
- if (
149
- not exists(self.checkpoint_path)
150
- or not os.path.exists(self.checkpoint_path)
151
- or not os.listdir(self.checkpoint_path)
152
- ):
153
- return 0
154
-
155
- self.accelerator.wait_for_everyone()
156
- if "model_last.pt" in os.listdir(self.checkpoint_path):
157
- latest_checkpoint = "model_last.pt"
158
- else:
159
- latest_checkpoint = sorted(
160
- [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
161
- key=lambda x: int("".join(filter(str.isdigit, x))),
162
- )[-1]
163
- # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
164
- checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
165
-
166
- # patch for backward compatibility, 305e3ea
167
- for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
168
- if key in checkpoint["ema_model_state_dict"]:
169
- del checkpoint["ema_model_state_dict"][key]
170
-
171
- if self.is_main:
172
- self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
173
-
174
- if "step" in checkpoint:
175
- # patch for backward compatibility, 305e3ea
176
- for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
177
- if key in checkpoint["model_state_dict"]:
178
- del checkpoint["model_state_dict"][key]
179
-
180
- self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
181
- self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
182
- if self.scheduler:
183
- self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
184
- step = checkpoint["step"]
185
- else:
186
- checkpoint["model_state_dict"] = {
187
- k.replace("ema_model.", ""): v
188
- for k, v in checkpoint["ema_model_state_dict"].items()
189
- if k not in ["initted", "step"]
190
- }
191
- self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
192
- step = 0
193
-
194
- del checkpoint
195
- gc.collect()
196
- return step
197
-
198
- def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
199
- if self.log_samples:
200
- from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef
201
-
202
- vocoder = load_vocoder(vocoder_name=self.vocoder_name)
203
- target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate
204
- log_samples_path = f"{self.checkpoint_path}/samples"
205
- os.makedirs(log_samples_path, exist_ok=True)
206
-
207
- if exists(resumable_with_seed):
208
- generator = torch.Generator()
209
- generator.manual_seed(resumable_with_seed)
210
- else:
211
- generator = None
212
-
213
- if self.batch_size_type == "sample":
214
- train_dataloader = DataLoader(
215
- train_dataset,
216
- collate_fn=collate_fn,
217
- num_workers=num_workers,
218
- pin_memory=True,
219
- persistent_workers=True,
220
- batch_size=self.batch_size,
221
- shuffle=True,
222
- generator=generator,
223
- )
224
- elif self.batch_size_type == "frame":
225
- self.accelerator.even_batches = False
226
- sampler = SequentialSampler(train_dataset)
227
- batch_sampler = DynamicBatchSampler(
228
- sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
229
- )
230
- train_dataloader = DataLoader(
231
- train_dataset,
232
- collate_fn=collate_fn,
233
- num_workers=num_workers,
234
- pin_memory=True,
235
- persistent_workers=True,
236
- batch_sampler=batch_sampler,
237
- )
238
- else:
239
- raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
240
-
241
- # accelerator.prepare() dispatches batches to devices;
242
- # which means the dataloader length computed above must account for the number of devices
243
- warmup_steps = (
244
- self.num_warmup_updates * self.accelerator.num_processes
245
- ) # consider a fixed warmup steps while using accelerate multi-gpu ddp
246
- # otherwise by default with split_batches=False, warmup steps change with num_processes
247
- total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
248
- decay_steps = total_steps - warmup_steps
249
- warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
250
- decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
251
- self.scheduler = SequentialLR(
252
- self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
253
- )
254
- train_dataloader, self.scheduler = self.accelerator.prepare(
255
- train_dataloader, self.scheduler
256
- ) # actual steps = 1 gpu steps / gpus
257
- start_step = self.load_checkpoint()
258
- global_step = start_step
259
-
260
- if exists(resumable_with_seed):
261
- orig_epoch_step = len(train_dataloader)
262
- skipped_epoch = int(start_step // orig_epoch_step)
263
- skipped_batch = start_step % orig_epoch_step
264
- skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
265
- else:
266
- skipped_epoch = 0
267
-
268
- for epoch in range(skipped_epoch, self.epochs):
269
- self.model.train()
270
- if exists(resumable_with_seed) and epoch == skipped_epoch:
271
- progress_bar = tqdm(
272
- skipped_dataloader,
273
- desc=f"Epoch {epoch+1}/{self.epochs}",
274
- unit="step",
275
- disable=not self.accelerator.is_local_main_process,
276
- initial=skipped_batch,
277
- total=orig_epoch_step,
278
- )
279
- else:
280
- progress_bar = tqdm(
281
- train_dataloader,
282
- desc=f"Epoch {epoch+1}/{self.epochs}",
283
- unit="step",
284
- disable=not self.accelerator.is_local_main_process,
285
- )
286
-
287
- for batch in progress_bar:
288
- with self.accelerator.accumulate(self.model):
289
- text_inputs = batch["text"]
290
- mel_spec = batch["mel"].permute(0, 2, 1)
291
- mel_lengths = batch["mel_lengths"]
292
-
293
- # TODO. add duration predictor training
294
- if self.duration_predictor is not None and self.accelerator.is_local_main_process:
295
- dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
296
- self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
297
-
298
- loss, cond, pred = self.model(
299
- mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
300
- )
301
- self.accelerator.backward(loss)
302
-
303
- if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
304
- self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
305
-
306
- self.optimizer.step()
307
- self.scheduler.step()
308
- self.optimizer.zero_grad()
309
-
310
- if self.is_main:
311
- self.ema_model.update()
312
-
313
- global_step += 1
314
-
315
- if self.accelerator.is_local_main_process:
316
- self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
317
- if self.logger == "tensorboard":
318
- self.writer.add_scalar("loss", loss.item(), global_step)
319
- self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step)
320
-
321
- progress_bar.set_postfix(step=str(global_step), loss=loss.item())
322
-
323
- if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
324
- self.save_checkpoint(global_step)
325
-
326
- if self.log_samples and self.accelerator.is_local_main_process:
327
- ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0)), mel_lengths[0]
328
- torchaudio.save(
329
- f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio.cpu(), target_sample_rate
330
- )
331
- with torch.inference_mode():
332
- generated, _ = self.accelerator.unwrap_model(self.model).sample(
333
- cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
334
- text=[text_inputs[0] + [" "] + text_inputs[0]],
335
- duration=ref_audio_len * 2,
336
- steps=nfe_step,
337
- cfg_strength=cfg_strength,
338
- sway_sampling_coef=sway_sampling_coef,
339
- )
340
- generated = generated.to(torch.float32)
341
- gen_audio = vocoder.decode(
342
- generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
343
- )
344
- torchaudio.save(
345
- f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio.cpu(), target_sample_rate
346
- )
347
-
348
- if global_step % self.last_per_steps == 0:
349
- self.save_checkpoint(global_step, last=True)
350
-
351
- self.save_checkpoint(global_step, last=True)
352
-
353
- self.accelerator.end_training()
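
For reference, a rough sketch of how this Trainer is typically driven, in the spirit of the upstream training script. It assumes the package's CFM, DiT, dataset, and tokenizer helpers keep their upstream signatures; every path and hyperparameter below is a placeholder rather than a recommendation.

# Illustrative only: assemble a CFM model and hand it to the Trainer above.
from f5_tts.model import CFM, DiT, Trainer           # assumed exports
from f5_tts.model.dataset import load_dataset        # assumed helper
from f5_tts.model.utils import get_tokenizer

vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")  # placeholder dataset name
model = CFM(
    transformer=DiT(
        dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4,
        text_num_embeds=vocab_size, mel_dim=100,     # assumed from the upstream DiT signature
    ),
    vocab_char_map=vocab_char_map,
)

trainer = Trainer(
    model,
    epochs=1,
    learning_rate=1e-5,
    num_warmup_updates=2000,
    save_per_updates=10000,
    checkpoint_path="ckpts/my_run",                  # placeholder
    batch_size=38400,                                # frames per GPU with batch_size_type="frame"
    batch_size_type="frame",
    max_samples=64,
    grad_accumulation_steps=1,
    logger=None,                                     # or "wandb" / "tensorboard"
)

train_dataset = load_dataset("Emilia_ZH_EN", "pinyin")  # placeholder dataset
trainer.train(train_dataset, num_workers=4)
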
src/f5-tts/model/utils.py DELETED
@@ -1,185 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import os
4
- import random
5
- from collections import defaultdict
6
- from importlib.resources import files
7
-
8
- import torch
9
- from torch.nn.utils.rnn import pad_sequence
10
-
11
- import jieba
12
- from pypinyin import lazy_pinyin, Style
13
-
14
-
15
- # seed everything
16
-
17
-
18
- def seed_everything(seed=0):
19
- random.seed(seed)
20
- os.environ["PYTHONHASHSEED"] = str(seed)
21
- torch.manual_seed(seed)
22
- torch.cuda.manual_seed(seed)
23
- torch.cuda.manual_seed_all(seed)
24
- torch.backends.cudnn.deterministic = True
25
- torch.backends.cudnn.benchmark = False
26
-
27
-
28
- # helpers
29
-
30
-
31
- def exists(v):
32
- return v is not None
33
-
34
-
35
- def default(v, d):
36
- return v if exists(v) else d
37
-
38
-
39
- # tensor helpers
40
-
41
-
42
- def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821
43
- if not exists(length):
44
- length = t.amax()
45
-
46
- seq = torch.arange(length, device=t.device)
47
- return seq[None, :] < t[:, None]
48
-
49
-
50
- def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821
51
- max_seq_len = seq_len.max().item()
52
- seq = torch.arange(max_seq_len, device=start.device).long()
53
- start_mask = seq[None, :] >= start[:, None]
54
- end_mask = seq[None, :] < end[:, None]
55
- return start_mask & end_mask
56
-
57
-
58
- def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821
59
- lengths = (frac_lengths * seq_len).long()
60
- max_start = seq_len - lengths
61
-
62
- rand = torch.rand_like(frac_lengths)
63
- start = (max_start * rand).long().clamp(min=0)
64
- end = start + lengths
65
-
66
- return mask_from_start_end_indices(seq_len, start, end)
67
-
68
-
69
- def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722
70
- if not exists(mask):
71
- return t.mean(dim=1)
72
-
73
- t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
74
- num = t.sum(dim=1)
75
- den = mask.float().sum(dim=1)
76
-
77
- return num / den.clamp(min=1.0)
78
-
79
-
80
- # simple utf-8 tokenizer, since paper went character based
81
- def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722
82
- list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style
83
- text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
84
- return text
85
-
86
-
87
- # char tokenizer, based on custom dataset's extracted .txt file
88
- def list_str_to_idx(
89
- text: list[str] | list[list[str]],
90
- vocab_char_map: dict[str, int], # {char: idx}
91
- padding_value=-1,
92
- ) -> int["b nt"]: # noqa: F722
93
- list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
94
- text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
95
- return text
96
-
97
-
98
- # Get tokenizer
99
-
100
-
101
- def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
102
- """
103
- tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
104
- - "char" for char-wise tokenizer, need .txt vocab_file
105
- - "byte" for utf-8 tokenizer
106
- - "custom" if you're directly passing in a path to the vocab.txt you want to use
107
- vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
108
- - if use "char", derived from unfiltered character & symbol counts of custom dataset
109
- - if use "byte", set to 256 (unicode byte range)
110
- """
111
- if tokenizer in ["pinyin", "char"]:
112
- tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt")
113
- with open(tokenizer_path, "r", encoding="utf-8") as f:
114
- vocab_char_map = {}
115
- for i, char in enumerate(f):
116
- vocab_char_map[char[:-1]] = i
117
- vocab_size = len(vocab_char_map)
118
- assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
119
-
120
- elif tokenizer == "byte":
121
- vocab_char_map = None
122
- vocab_size = 256
123
-
124
- elif tokenizer == "custom":
125
- with open(dataset_name, "r", encoding="utf-8") as f:
126
- vocab_char_map = {}
127
- for i, char in enumerate(f):
128
- vocab_char_map[char[:-1]] = i
129
- vocab_size = len(vocab_char_map)
130
-
131
- return vocab_char_map, vocab_size
132
-
133
-
134
- # convert char to pinyin
135
-
136
-
137
- def convert_char_to_pinyin(text_list, polyphone=True):
138
- final_text_list = []
139
- god_knows_why_en_testset_contains_zh_quote = str.maketrans(
140
- {"“": '"', "”": '"', "‘": "'", "’": "'"}
141
- ) # in case librispeech (orig no-pc) test-clean
142
- custom_trans = str.maketrans({";": ","}) # add custom trans here, to address oov
143
- for text in text_list:
144
- char_list = []
145
- text = text.translate(god_knows_why_en_testset_contains_zh_quote)
146
- text = text.translate(custom_trans)
147
- for seg in jieba.cut(text):
148
- seg_byte_len = len(bytes(seg, "UTF-8"))
149
- if seg_byte_len == len(seg): # if pure alphabets and symbols
150
- if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
151
- char_list.append(" ")
152
- char_list.extend(seg)
153
- elif polyphone and seg_byte_len == 3 * len(seg): # if pure chinese characters
154
- seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
155
- for c in seg:
156
- if c not in "。,、;:?!《》【】—…":
157
- char_list.append(" ")
158
- char_list.append(c)
159
- else: # if mixed chinese characters, alphabets and symbols
160
- for c in seg:
161
- if ord(c) < 256:
162
- char_list.extend(c)
163
- else:
164
- if c not in "。,、;:?!《》【】—…":
165
- char_list.append(" ")
166
- char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
167
- else: # if is zh punc
168
- char_list.append(c)
169
- final_text_list.append(char_list)
170
-
171
- return final_text_list
172
-
173
-
174
- # filter func for dirty data with many repetitions
175
-
176
-
177
- def repetition_found(text, length=2, tolerance=10):
178
- pattern_count = defaultdict(int)
179
- for i in range(len(text) - length + 1):
180
- pattern = text[i : i + length]
181
- pattern_count[pattern] += 1
182
- for pattern, count in pattern_count.items():
183
- if count > tolerance:
184
- return True
185
- return False
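
For reference, a minimal sketch of the text pipeline these helpers form: mixed Chinese/English input is segmented with jieba, Chinese segments are converted to tone3 pinyin, and the resulting character lists are mapped to vocab indices. The vocab path below is a placeholder.

# Illustrative only: exercise the helpers above on a mixed-language sample.
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer, list_str_to_idx  # assumed import path

texts = ["Hello, 世界! This is F5-TTS."]
pinyin_lists = convert_char_to_pinyin(texts, polyphone=True)
print(pinyin_lists[0])  # English kept as characters, 世界 rendered as tone3 pinyin

# With tokenizer="custom", dataset_name is itself the path to a vocab.txt (placeholder path):
vocab_char_map, vocab_size = get_tokenizer("data/my_dataset/vocab.txt", tokenizer="custom")
token_ids = list_str_to_idx(pinyin_lists, vocab_char_map)  # shape (b, nt), padded with -1
print(token_ids.shape, vocab_size)
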
src/f5-tts/socket.py DELETED
@@ -1,159 +0,0 @@
1
- import socket
2
- import struct
3
- import torch
4
- import torchaudio
5
- from threading import Thread
6
-
7
-
8
- import gc
9
- import traceback
10
-
11
-
12
- from infer.utils_infer import infer_batch_process, preprocess_ref_audio_text, load_vocoder, load_model
13
- from model.backbones.dit import DiT
14
-
15
-
16
- class TTSStreamingProcessor:
17
- def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
18
- self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
19
-
20
- # Load the model using the provided checkpoint and vocab files
21
- self.model = load_model(
22
- DiT,
23
- dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
24
- ckpt_file,
25
- vocab_file,
26
- ).to(self.device, dtype=dtype)
27
-
28
- # Load the vocoder
29
- self.vocoder = load_vocoder(is_local=False)
30
-
31
- # Set sampling rate for streaming
32
- self.sampling_rate = 24000 # Consistency with client
33
-
34
- # Set reference audio and text
35
- self.ref_audio = ref_audio
36
- self.ref_text = ref_text
37
-
38
- # Warm up the model
39
- self._warm_up()
40
-
41
- def _warm_up(self):
42
- """Warm up the model with a dummy input to ensure it's ready for real-time processing."""
43
- print("Warming up the model...")
44
- ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)
45
- audio, sr = torchaudio.load(ref_audio)
46
- gen_text = "Warm-up text for the model."
47
-
48
- # Pass the vocoder as an argument here
49
- infer_batch_process((audio, sr), ref_text, [gen_text], self.model, self.vocoder, device=self.device)
50
- print("Warm-up completed.")
51
-
52
- def generate_stream(self, text, play_steps_in_s=0.5):
53
- """Generate audio in chunks and yield them in real-time."""
54
- # Preprocess the reference audio and text
55
- ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)
56
-
57
- # Load reference audio
58
- audio, sr = torchaudio.load(ref_audio)
59
-
60
- # Run inference for the input text
61
- audio_chunk, final_sample_rate, _ = infer_batch_process(
62
- (audio, sr),
63
- ref_text,
64
- [text],
65
- self.model,
66
- self.vocoder,
67
- device=self.device,
68
- )
69
-
70
- # Break the generated audio into chunks and send them
71
- chunk_size = int(final_sample_rate * play_steps_in_s)
72
-
73
- for i in range(0, len(audio_chunk), chunk_size):
74
- chunk = audio_chunk[i : i + chunk_size]
75
-
76
- # Check if it's the final chunk
77
- if i + chunk_size >= len(audio_chunk):
78
- chunk = audio_chunk[i:]
79
-
80
- # Avoid sending empty or repeated chunks
81
- if len(chunk) == 0:
82
- break
83
-
84
- # Pack and send the audio chunk
85
- packed_audio = struct.pack(f"{len(chunk)}f", *chunk)
86
- yield packed_audio
87
-
88
- # The final partial chunk is already yielded by the loop above;
89
- # resending it here would duplicate the tail of the audio.
90
-
91
-
92
-
93
-
94
-
95
- def handle_client(client_socket, processor):
96
- try:
97
- while True:
98
- # Receive data from the client
99
- data = client_socket.recv(1024).decode("utf-8")
100
- if not data:
101
- break
102
-
103
- try:
104
- # The client sends the text input
105
- text = data.strip()
106
-
107
- # Generate and stream audio chunks
108
- for audio_chunk in processor.generate_stream(text):
109
- client_socket.sendall(audio_chunk)
110
-
111
- # Send end-of-audio signal
112
- client_socket.sendall(b"END_OF_AUDIO")
113
-
114
- except Exception as inner_e:
115
- print(f"Error during processing: {inner_e}")
116
- traceback.print_exc() # Print the full traceback to diagnose the issue
117
- break
118
-
119
- except Exception as e:
120
- print(f"Error handling client: {e}")
121
- traceback.print_exc()
122
- finally:
123
- client_socket.close()
124
-
125
-
126
- def start_server(host, port, processor):
127
- server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
128
- server.bind((host, port))
129
- server.listen(5)
130
- print(f"Server listening on {host}:{port}")
131
-
132
- while True:
133
- client_socket, addr = server.accept()
134
- print(f"Accepted connection from {addr}")
135
- client_handler = Thread(target=handle_client, args=(client_socket, processor))
136
- client_handler.start()
137
-
138
-
139
- if __name__ == "__main__":
140
- try:
141
- # Load the model and vocoder using the provided files
142
- ckpt_file = "" # pointing your checkpoint "ckpts/model/model_1096.pt"
143
- vocab_file = "" # Add vocab file path if needed
144
- ref_audio = "" # add ref audio"./tests/ref_audio/reference.wav"
145
- ref_text = ""
146
-
147
- # Initialize the processor with the model and vocoder
148
- processor = TTSStreamingProcessor(
149
- ckpt_file=ckpt_file,
150
- vocab_file=vocab_file,
151
- ref_audio=ref_audio,
152
- ref_text=ref_text,
153
- dtype=torch.float32,
154
- )
155
-
156
- # Start the server
157
- start_server("0.0.0.0", 9998, processor)
158
- except KeyboardInterrupt:
159
- gc.collect()
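
For reference, a matching client sketch for the protocol above: the server receives UTF-8 text, streams raw float32 samples back over TCP, and ends each utterance with the literal b"END_OF_AUDIO". Host, port, and the 24 kHz rate mirror the server defaults; everything else is illustrative.

# Illustrative client only; assumes the server above is listening on localhost:9998.
import socket
import struct

import numpy as np


def synthesize(text: str, host: str = "localhost", port: int = 9998) -> np.ndarray:
    end_marker = b"END_OF_AUDIO"
    with socket.create_connection((host, port)) as sock:
        sock.sendall(text.encode("utf-8"))
        buffer = b""
        while True:
            data = sock.recv(4096)
            if not data:
                break
            buffer += data
            if buffer.endswith(end_marker):
                buffer = buffer[: -len(end_marker)]
                break
    # The server packs each chunk as float32 samples (struct.pack "<n>f"), so decode them the same way.
    samples = struct.unpack(f"{len(buffer) // 4}f", buffer)
    return np.asarray(samples, dtype=np.float32)  # mono at 24 kHz


if __name__ == "__main__":
    audio = synthesize("Hello from the streaming client.")
    print(audio.shape)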