import os

import torch
import torch.nn.functional as F
import torchaudio
from safetensors.torch import load_file
from torch.nn.utils.rnn import pad_sequence
from torchdiffeq import odeint

from duration_predictor import SpeechLengthPredictor
from f5_tts.infer.utils_infer import (chunk_text, convert_char_to_pinyin,
                                      hop_length, load_vocoder,
                                      preprocess_ref_audio_text, speed,
                                      target_rms, target_sample_rate,
                                      transcribe)
# Import F5-TTS modules
from f5_tts.model import CFM, DiT, UNetT
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import (default, exists, get_tokenizer, lens_to_mask,
                                list_str_to_idx, list_str_to_tensor,
                                mask_from_frac_lengths)
# Import custom modules
from unimodel import UniModel

class DMOInference:
    """F5-TTS inference wrapper class for easy text-to-speech generation."""

    def __init__(
        self,
        student_checkpoint_path="",
        duration_predictor_path="",
        device="cuda",
        model_type="F5TTS_Base",  # "F5TTS_Base" or "E2TTS_Base"
        tokenizer="pinyin",
        dataset_name="Emilia_ZH_EN",
    ):
""" | |
Initialize F5-TTS inference model. | |
Args: | |
student_checkpoint_path: Path to student model checkpoint | |
duration_predictor_path: Path to duration predictor checkpoint | |
device: Device to run inference on | |
model_type: Model architecture type | |
tokenizer: Tokenizer type ("pinyin", "char", or "custom") | |
dataset_name: Dataset name for tokenizer | |
cuda_device_id: CUDA device ID to use | |
""" | |
        self.device = device
        self.model_type = model_type
        self.tokenizer = tokenizer
        self.dataset_name = dataset_name

        # Model parameters
        self.target_sample_rate = 24000
        self.n_mel_channels = 100
        self.hop_length = 256
        self.real_guidance_scale = 2
        self.fake_guidance_scale = 0
        self.gen_cls_loss = False
        self.num_student_step = 4

        # Initialize components
        self._setup_tokenizer()
        self._setup_models(student_checkpoint_path)
        self._setup_mel_spec()
        self._setup_vocoder()
        self._setup_duration_predictor(duration_predictor_path)

    def _setup_tokenizer(self):
        """Setup tokenizer and vocabulary."""
        if self.tokenizer == "custom":
            # Note: a "custom" tokenizer expects `self.tokenizer_path` to be set
            # before this runs; __init__ only configures "pinyin"/"char".
            tokenizer_path = self.tokenizer_path
        else:
            tokenizer_path = self.dataset_name
        self.vocab_char_map, self.vocab_size = get_tokenizer(
            tokenizer_path, self.tokenizer
        )

    def _setup_models(self, student_checkpoint_path):
        """Initialize teacher and student models."""
        # Model configuration
        if self.model_type == "F5TTS_Base":
            model_cls = DiT
            model_cfg = dict(
                dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
            )
        elif self.model_type == "E2TTS_Base":
            model_cls = UNetT
            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
        else:
            raise ValueError(f"Unknown model type: {self.model_type}")

        # Initialize UniModel (student)
        self.model = UniModel(
            model_cls(
                **model_cfg,
                text_num_embeds=self.vocab_size,
                mel_dim=self.n_mel_channels,
                second_time=self.num_student_step > 1,
            ),
            checkpoint_path="",
            vocab_char_map=self.vocab_char_map,
            frac_lengths_mask=(0.5, 0.9),
            real_guidance_scale=self.real_guidance_scale,
            fake_guidance_scale=self.fake_guidance_scale,
            gen_cls_loss=self.gen_cls_loss,
            sway_coeff=0,
        )

        # Load student checkpoint
        checkpoint = torch.load(student_checkpoint_path, map_location="cpu")
        self.model.load_state_dict(checkpoint["model_state_dict"], strict=False)

        # Setup generator and teacher
        self.generator = self.model.feedforward_model.to(self.device)
        self.teacher = self.model.guidance_model.real_unet.to(self.device)
        self.scale = checkpoint["scale"]

    def _setup_mel_spec(self):
        """Initialize mel spectrogram module."""
        mel_spec_kwargs = dict(
            target_sample_rate=self.target_sample_rate,
            n_mel_channels=self.n_mel_channels,
            hop_length=self.hop_length,
        )
        self.mel_spec = MelSpec(**mel_spec_kwargs)

    def _setup_vocoder(self):
        """Initialize vocoder."""
        self.vocos = load_vocoder(is_local=False, local_path="")
        self.vocos = self.vocos.to(self.device)

    def _setup_duration_predictor(self, checkpoint_path):
        """Initialize duration predictor."""
        self.wav2mel = MelSpec(
            target_sample_rate=24000,
            n_mel_channels=100,
            hop_length=256,
            win_length=1024,
            n_fft=1024,
            mel_spec_type="vocos",
        ).to(self.device)

        self.SLP = SpeechLengthPredictor(
            vocab_size=2545,
            n_mel=100,
            hidden_dim=512,
            n_text_layer=4,
            n_cross_layer=4,
            n_head=8,
            output_dim=301,
        ).to(self.device)
        self.SLP.eval()
        self.SLP.load_state_dict(
            torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
        )

    def predict_duration(
        self, pmt_wav_path, tar_text, pmt_text, dp_softmax_range=0.7, temperature=0
    ):
        """
        Predict duration for target text based on prompt audio.

        Args:
            pmt_wav_path: Path to prompt audio
            tar_text: Target text to generate
            pmt_text: Prompt text
            dp_softmax_range: Fractional window around the rate-based duration;
                bins outside [range * d, (1 + range) * d] are masked out
            temperature: Temperature for softmax sampling (0 uses argmax)

        Returns:
            Estimated duration in frames
        """
        pmt_wav, sr = torchaudio.load(pmt_wav_path)
        if sr != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate)
            pmt_wav = resampler(pmt_wav)
        if pmt_wav.size(0) > 1:
            pmt_wav = pmt_wav[0].unsqueeze(0)
        pmt_wav = pmt_wav.to(self.device)
        pmt_mel = self.wav2mel(pmt_wav).permute(0, 2, 1)

        tar_tokens = self._convert_to_pinyin(list(tar_text))
        pmt_tokens = self._convert_to_pinyin(list(pmt_text))

        # Calculate duration
        ref_text_len = len(pmt_tokens)
        gen_text_len = len(tar_tokens)
        ref_audio_len = pmt_mel.size(1)
        duration = int(ref_audio_len / ref_text_len * gen_text_len / speed)
        duration = duration // 10
        min_duration = max(int(duration * dp_softmax_range), 0)
        max_duration = min(int(duration * (1 + dp_softmax_range)), 301)

        all_tokens = pmt_tokens + [" "] + tar_tokens
        text_ids = list_str_to_idx([all_tokens], self.vocab_char_map).to(self.device)
        text_ids = text_ids.masked_fill(text_ids == -1, self.vocab_size)

        with torch.no_grad():
            predictions = self.SLP(text_ids=text_ids, mel=pmt_mel)
            predictions = predictions[:, -1, :]
            predictions[:, :min_duration] = float("-inf")
            predictions[:, max_duration:] = float("-inf")

            if temperature == 0:
                est_label = predictions.argmax(-1)[..., -1].item() * 10
            else:
                probs = torch.softmax(predictions / temperature, dim=-1)
                sampled_idx = torch.multinomial(
                    probs.squeeze(0), num_samples=1
                )  # Sample a duration bin from the tempered distribution
                est_label = sampled_idx.item() * 10

        return est_label

    def _convert_to_pinyin(self, char_list):
        """Convert character list to pinyin."""
        result = []
        for x in convert_char_to_pinyin(char_list):
            result = result + x
        while result[0] == " " and len(result) > 1:
            result = result[1:]
        return result

    def generate(
        self,
        gen_text,
        audio_path,
        prompt_text=None,
        teacher_steps=16,
        teacher_stopping_time=0.07,
        student_start_step=1,
        duration=None,
        dp_softmax_range=0.7,
        temperature=0,
        eta=1.0,
        cfg_strength=2.0,
        sway_coefficient=-1.0,
        verbose=False,
    ):
        """
        Generate speech from text using teacher-student distillation.

        Args:
            gen_text: Text to generate
            audio_path: Path to prompt audio
            prompt_text: Prompt text (if None, obtained via ASR)
            teacher_steps: Number of teacher guidance steps
            teacher_stopping_time: When to stop teacher sampling
            student_start_step: When to start student sampling
            duration: Total duration in frames (if None, predicted automatically)
            dp_softmax_range: Duration predictor softmax window around the rate-based duration
            temperature: Temperature for duration predictor sampling (0 means argmax)
            eta: Stochasticity control (0=DDIM-like, 1=DDPM-like)
            cfg_strength: Classifier-free guidance strength
            sway_coefficient: Sway sampling coefficient
            verbose: Print sampling steps

        Returns:
            Generated audio waveform
        """
        if prompt_text is None:
            prompt_text = transcribe(audio_path)

        # Predict duration if not provided
        if duration is None:
            duration = self.predict_duration(
                audio_path, gen_text, prompt_text, dp_softmax_range, temperature
            )

        # Preprocess audio and text
        ref_audio, ref_text = preprocess_ref_audio_text(audio_path, prompt_text)
        audio, sr = torchaudio.load(ref_audio)
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)

        # Normalize audio
        rms = torch.sqrt(torch.mean(torch.square(audio)))
        if rms < target_rms:
            audio = audio * target_rms / rms
        if sr != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate)
            audio = resampler(audio)
        audio = audio.to(self.device)

        # Prepare text
        text_list = [ref_text + gen_text]
        final_text_list = convert_char_to_pinyin(text_list)

        # Calculate durations
        ref_audio_len = audio.shape[-1] // self.hop_length
        if duration is None:
            ref_text_len = len(ref_text.encode("utf-8"))
            gen_text_len = len(gen_text.encode("utf-8"))
            duration = ref_audio_len + int(
                ref_audio_len / ref_text_len * gen_text_len / speed
            )
        else:
            duration = ref_audio_len + duration

        if verbose:
            print("audio:", audio.shape)
            print("text:", final_text_list)
            print("duration:", duration)
            print("eta (stochasticity):", eta)  # Print eta value for debugging

        # Run inference
        with torch.inference_mode():
            cond, text, step_cond, cond_mask, max_duration, duration_tensor = (
                self._prepare_inputs(audio, final_text_list, duration)
            )

            # Teacher-student sampling
            if teacher_steps > 0 and student_start_step > 0:
                if verbose:
                    print(
                        "Start teacher sampling with hybrid DDIM/DDPM (eta={})....".format(
                            eta
                        )
                    )
                x1 = self._teacher_sampling(
                    step_cond,
                    text,
                    cond_mask,
                    max_duration,
                    duration_tensor,  # Use duration_tensor
                    teacher_steps,
                    teacher_stopping_time,
                    eta,
                    cfg_strength,
                    verbose,
                    sway_coefficient,
                )
            else:
                x1 = step_cond

            if verbose:
                print("Start student sampling...")

            # Student sampling
            x1 = self._student_sampling(
                x1, cond, text, student_start_step, verbose, sway_coefficient
            )

            # Decode to audio
            mel = x1.permute(0, 2, 1) * self.scale
            generated_wave = self.vocos.decode(mel[..., cond_mask.sum() :])

        return generated_wave.cpu().numpy().squeeze()

    def generate_teacher_only(
        self,
        gen_text,
        audio_path,
        prompt_text=None,
        teacher_steps=32,
        duration=None,
        eta=1.0,
        cfg_strength=2.0,
        sway_coefficient=-1.0,
    ):
        """
        Generate speech using the teacher model only (no student distillation).

        Args:
            gen_text: Text to generate
            audio_path: Path to prompt audio
            prompt_text: Prompt text (if None, will use ASR)
            teacher_steps: Number of sampling steps
            duration: Total duration (if None, will predict)
            eta: Stochasticity control (0=DDIM, 1=DDPM)
            cfg_strength: Classifier-free guidance strength
            sway_coefficient: Sway sampling coefficient

        Returns:
            Generated audio waveform
        """
        if prompt_text is None:
            prompt_text = transcribe(audio_path)

        # Predict duration if not provided
        if duration is None:
            duration = self.predict_duration(audio_path, gen_text, prompt_text)

        # Preprocess audio and text
        ref_audio, ref_text = preprocess_ref_audio_text(audio_path, prompt_text)
        audio, sr = torchaudio.load(ref_audio)
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)

        # Normalize audio
        rms = torch.sqrt(torch.mean(torch.square(audio)))
        if rms < target_rms:
            audio = audio * target_rms / rms
        if sr != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate)
            audio = resampler(audio)
        audio = audio.to(self.device)

        # Prepare text
        text_list = [ref_text + gen_text]
        final_text_list = convert_char_to_pinyin(text_list)

        # Calculate durations
        ref_audio_len = audio.shape[-1] // self.hop_length
        if duration is None:
            ref_text_len = len(ref_text.encode("utf-8"))
            gen_text_len = len(gen_text.encode("utf-8"))
            duration = ref_audio_len + int(
                ref_audio_len / ref_text_len * gen_text_len / speed
            )
        else:
            duration = ref_audio_len + duration
        # Run inference
        with torch.inference_mode():
            cond, text, step_cond, cond_mask, max_duration, duration_tensor = (
                self._prepare_inputs(audio, final_text_list, duration)
            )

            # Teacher-only sampling
            x1 = self._teacher_sampling(
                step_cond,
                text,
                cond_mask,
                max_duration,
                duration_tensor,
                teacher_steps,
                1.0,  # stopping_time=1.0 for full sampling
                eta,
                cfg_strength,
                verbose=False,
                sway_sampling_coef=sway_coefficient,
            )

            # Decode to audio
            mel = x1.permute(0, 2, 1) * self.scale
            generated_wave = self.vocos.decode(mel[..., cond_mask.sum() :])

        return generated_wave

    def _prepare_inputs(self, audio, text_list, duration):
        """Prepare inputs for generation."""
        lens = None
        max_duration_limit = 4096

        cond = audio
        text = text_list

        if cond.ndim == 2:
            cond = self.mel_spec(cond)
            cond = cond.permute(0, 2, 1)
            assert cond.shape[-1] == 100
        cond = cond / self.scale

        batch, cond_seq_len, device = *cond.shape[:2], cond.device
        if not exists(lens):
            lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)

        # Process text
        if isinstance(text, list):
            if exists(self.vocab_char_map):
                text = list_str_to_idx(text, self.vocab_char_map).to(device)
            else:
                text = list_str_to_tensor(text).to(device)
            assert text.shape[0] == batch

        if exists(text):
            text_lens = (text != -1).sum(dim=-1)
            lens = torch.maximum(text_lens, lens)

        # Process duration
        cond_mask = lens_to_mask(lens)
        if isinstance(duration, int):
            duration = torch.full((batch,), duration, device=device, dtype=torch.long)
        duration = torch.maximum(lens + 1, duration)
        duration = duration.clamp(max=max_duration_limit)
        max_duration = duration.amax()

        # Pad conditioning
        cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
        cond_mask = F.pad(
            cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False
        )
        cond_mask = cond_mask.unsqueeze(-1)
        step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))

        return cond, text, step_cond, cond_mask, max_duration, duration

    def _teacher_sampling(
        self,
        step_cond,
        text,
        cond_mask,
        max_duration,
        duration,
        teacher_steps,
        teacher_stopping_time,
        eta,
        cfg_strength,
        verbose,
        sway_sampling_coef=-1,
    ):
        """Perform teacher model sampling."""
        device = step_cond.device

        # Pre-generate noise sequence for stochastic sampling
        noise_seq = None
        if eta > 0:
            noise_seq = [
                torch.randn(1, max_duration, 100, device=device)
                for _ in range(teacher_steps)
            ]

        def fn(t, x):
            with torch.inference_mode():
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    if verbose:
                        print(f"current t: {t}")
                    step_frac = 1.0 - t.item()
                    step_idx = (
                        min(int(step_frac * len(noise_seq)), len(noise_seq) - 1)
                        if noise_seq
                        else 0
                    )

                    # Predict flow
                    pred = self.teacher(
                        x=x,
                        cond=step_cond,
                        text=text,
                        time=t,
                        mask=None,
                        drop_audio_cond=False,
                        drop_text=False,
                    )
                    if cfg_strength > 1e-5:
                        null_pred = self.teacher(
                            x=x,
                            cond=step_cond,
                            text=text,
                            time=t,
                            mask=None,
                            drop_audio_cond=True,
                            drop_text=True,
                        )
                        pred = pred + (pred - null_pred) * cfg_strength

                    # Add stochasticity if eta > 0
                    if eta > 0 and noise_seq is not None:
                        alpha_t = 1.0 - t.item()
                        sigma_t = t.item()
                        noise_scale = torch.sqrt(
                            torch.tensor(
                                (sigma_t**2) / (alpha_t**2 + sigma_t**2) * eta,
                                device=device,
                            )
                        )
                        return pred + noise_scale * noise_seq[step_idx]
                    else:
                        return pred

        # Initialize noise
        y0 = []
        for dur in duration:
            y0.append(torch.randn(dur, 100, device=device, dtype=step_cond.dtype))
        y0 = pad_sequence(y0, padding_value=0, batch_first=True)

        # Setup time steps
        t = torch.linspace(
            0, 1, teacher_steps + 1, device=device, dtype=step_cond.dtype
        )
        if sway_sampling_coef is not None:
            t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
        t = t[: (t > teacher_stopping_time).float().argmax() + 2]
        t = t[:-1]

        # Solve ODE
        trajectory = odeint(fn, y0, t, method="euler")

        if teacher_stopping_time < 1.0:
            # If early stopping, compute final step
            pred = fn(t[-1], trajectory[-1])
            test_out = trajectory[-1] + (1 - t[-1]) * pred
            return test_out
        else:
            return trajectory[-1]

    def _student_sampling(
        self, x1, cond, text, student_start_step, verbose, sway_coeff=-1
    ):
        """Perform student model sampling."""
        steps = torch.Tensor([0, 0.25, 0.5, 0.75])
        steps = steps + sway_coeff * (torch.cos(torch.pi / 2 * steps) - 1 + steps)
        steps = steps[student_start_step:]

        for step in steps:
            time = torch.Tensor([step]).to(x1.device)
            x0 = torch.randn_like(x1)
            t = time.unsqueeze(-1).unsqueeze(-1)
            phi = (1 - t) * x0 + t * x1

            if verbose:
                print(f"current step: {step}")

            with torch.no_grad():
                pred = self.generator(
                    x=phi,
                    cond=cond,
                    text=text,
                    time=time,
                    drop_audio_cond=False,
                    drop_text=False,
                )
                # Predicted mel spectrogram
                output = phi + (1 - t) * pred
                x1 = output

        return x1
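

# Minimal usage sketch (illustrative addition, not part of the original module):
# the checkpoint and audio paths below are hypothetical placeholders, and writing
# the waveform with soundfile is just one way to save the 24 kHz output returned
# by `generate`.
if __name__ == "__main__":
    import soundfile as sf

    tts = DMOInference(
        student_checkpoint_path="ckpts/student.pt",  # hypothetical path
        duration_predictor_path="ckpts/duration_predictor.pt",  # hypothetical path
        device="cuda",
    )
    wav = tts.generate(
        gen_text="Hello, this is a quick test of the distilled TTS model.",
        audio_path="prompts/reference.wav",  # hypothetical path
    )
    sf.write("output.wav", wav, tts.target_sample_rate)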