Spaces: Running on Zero
import os | |
import torch | |
import torch.nn.functional as F | |
import torchaudio | |
from einops import rearrange | |
from vocos import Vocos | |
from model import CFM, UNetT, DiT, MMDiT | |
from model.utils import ( | |
load_checkpoint, | |
get_tokenizer, | |
convert_char_to_pinyin, | |
save_spectrogram, | |
) | |
# Pick the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
# --------------------- Dataset Settings -------------------- #
target_sample_rate = 24_000  # Hz; the input clip is resampled to this rate and the result saved at it
n_mel_channels = 100  # mel bins per frame; also passed to the backbone as mel_dim
hop_length = 256  # audio samples per mel frame (used to convert samples <-> frames)
target_rms = 0.1  # input quieter than this RMS is boosted up to it before editing
tokenizer = "pinyin"  # text front-end name handed to get_tokenizer
dataset_name = "Emilia_ZH_EN"  # handed to get_tokenizer; presumably selects the vocab file
# ---------------------- infer setting ---------------------- #
seed = None  # int | None; forwarded to model.sample
exp_name = "F5TTS_Base"  # F5TTS_Base | E2TTS_Base
ckpt_step = 1_200_000  # selects ckpts/{exp_name}/model_{ckpt_step}.pt
nfe_step = 32  # forwarded to model.sample as `steps` (16 or 32)
cfg_strength = 2.0  # forwarded to model.sample as `cfg_strength`
ode_method = "euler"  # euler | midpoint; forwarded via odeint_kwargs
sway_sampling_coef = -1.0  # forwarded to model.sample
speed = 1.0  # not referenced by the code below
# Choose the backbone class and the hyper-parameters used to build it; the
# resulting model is then filled with the weights stored at `ckpt_path`.
if exp_name == "F5TTS_Base":
    model_cls = DiT
    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
elif exp_name == "E2TTS_Base":
    model_cls = UNetT
    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
else:
    # Without this branch an unknown name leaves model_cls/model_cfg unset and
    # the script dies later with a NameError at model construction.
    raise ValueError(f"Unknown exp_name {exp_name!r}; expected 'F5TTS_Base' or 'E2TTS_Base'")
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
output_dir = "tests"
# Character-level start/end times for `parts_to_edit` can be obtained with
# https://github.com/MahmoudAshraf97/ctc-forced-aligner :
#   pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
#   (first write origin_text into a file, e.g. tests/test_edit.txt)
#   ctc-forced-aligner --audio_path "tests/ref_audio/test_en_1_ref_short.wav" --text_path "tests/test_edit.txt" --language "zho" --romanize --split_size "char"
# The alignment result is saved next to the audio file.
# Use --language "zho" for Chinese and "eng" for English (NOTE(review): the
# example above passes "zho" for an English clip; "eng" is presumably intended).
# With a local checkpoint, add --alignment_model "../checkpoints/mms-300m-1130-forced-aligner".
audio_to_edit = "tests/ref_audio/test_en_1_ref_short.wav"
origin_text = "Some call me nature, others call me mother nature."
target_text = "Some call me optimist, others call me realist."
# [start, end] in seconds of "nature" and "mother nature" inside audio_to_edit.
parts_to_edit = [
    [1.42, 2.44],
    [4.04, 4.9],
]
# Replacement duration in seconds for "optimist" and "realist", one per part.
fix_duration = [
    1.2,
    1,
]

# Chinese example:
# audio_to_edit = "tests/ref_audio/test_zh_1_ref_short.wav"
# origin_text = "对,这就是我,万人敬仰的太乙真人。"
# target_text = "对,那就是你,万人敬仰的太白金星。"
# parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ]
# fix_duration = None  # keep each part's original duration
# -------------------------------------------------#
use_ema = True  # forwarded to load_checkpoint(..., use_ema=...) when loading the model
# exist_ok=True replaces the separate os.path.exists() check, which could race
# with another process creating the directory in between the two calls.
os.makedirs(output_dir, exist_ok=True)
# Vocoder model (mel spectrogram -> waveform). It is never moved to `device`;
# the generated mel is brought to the CPU with .cpu() before vocos.decode.
local = False
if local:
    vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
    vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
    # Load straight onto the CPU: mapping to `device` only copied the weights
    # to the accelerator so load_state_dict could copy them back again.
    state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location="cpu")
    vocos.load_state_dict(state_dict)
    vocos.eval()
else:
    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
# Tokenizer: vocabulary map and size for `dataset_name` / `tokenizer`.
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)

# Model: build the selected backbone, wrap it in CFM, move it to `device`,
# then load the checkpoint weights (EMA weights when use_ema is set).
transformer_backbone = model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels)
mel_settings = dict(
    target_sample_rate=target_sample_rate,
    n_mel_channels=n_mel_channels,
    hop_length=hop_length,
)
ode_settings = dict(method=ode_method)
model = CFM(
    transformer=transformer_backbone,
    mel_spec_kwargs=mel_settings,
    odeint_kwargs=ode_settings,
    vocab_char_map=vocab_char_map,
).to(device)
model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
# Audio: load the clip, downmix to mono, boost quiet input up to target_rms,
# and resample to target_sample_rate. `rms` is kept so the output can later be
# scaled back to the original loudness.
audio, sr = torchaudio.load(audio_to_edit)
if audio.shape[0] > 1:
    audio = torch.mean(audio, dim=0, keepdim=True)
rms = torch.sqrt(torch.mean(torch.square(audio)))
# Require rms > 0: for a fully silent clip the division below is 0/0 and
# would fill the audio with NaNs.
if 0 < rms < target_rms:
    audio = audio * target_rms / rms
if sr != target_sample_rate:
    resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
    audio = resampler(audio)
# Build the edited conditioning waveform and its frame-level edit mask.
# For each [start, end] span (seconds) in parts_to_edit, the original audio
# between the previous span's end and `start` is kept, and the span itself is
# replaced by `part_dur` seconds of silence: the original span length, or the
# matching entry of fix_duration when one is given. edit_mask is True on
# frames copied from the original audio and False on the inserted silence.
if fix_duration is not None and len(fix_duration) != len(parts_to_edit):
    raise ValueError(
        f"fix_duration has {len(fix_duration)} entries but parts_to_edit has {len(parts_to_edit)}"
    )
offset = 0  # end of the previous span, in samples
audio_ = torch.zeros(1, 0)
edit_mask = torch.zeros(1, 0, dtype=torch.bool)
for i, (start, end) in enumerate(parts_to_edit):
    if end < start or start * target_sample_rate < offset:
        raise ValueError(
            f"parts_to_edit must be ordered, non-overlapping [start, end] spans; got {parts_to_edit[i]!r}"
        )
    # Index instead of fix_duration.pop(0) so the config list is not consumed.
    part_dur = end - start if fix_duration is None else fix_duration[i]
    part_dur = part_dur * target_sample_rate  # seconds -> samples
    start = start * target_sample_rate
    audio_ = torch.cat(
        (audio_, audio[:, round(offset):round(start)], torch.zeros(1, round(part_dur))),
        dim=-1,
    )
    edit_mask = torch.cat(
        (
            edit_mask,
            torch.ones(1, round((start - offset) / hop_length), dtype=torch.bool),
            torch.zeros(1, round(part_dur / hop_length), dtype=torch.bool),
        ),
        dim=-1,
    )
    offset = end * target_sample_rate
# Append the audio after the last span and condition on the EDITED waveform.
# Previously this line was commented out: audio_ was discarded, the model saw
# the unedited clip, and whenever fix_duration differed from the span lengths
# the mask (built with the new durations) no longer lined up with the audio.
audio = torch.cat((audio_, audio[:, round(offset):]), dim=-1)
# Mark all remaining frames as kept. The +1 presumably matches a centered
# STFT's len // hop + 1 frame count -- confirm against the mel extractor.
edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
audio = audio.to(device)
edit_mask = edit_mask.to(device)
# Text: a batch with one entry per utterance (here a single target_text).
text_list = [target_text]
if tokenizer == "pinyin":
    # Presumably returns one token list per input text, keeping batch size 1.
    final_text_list = convert_char_to_pinyin(text_list)
else:
    # text_list is already a batch of one string. Wrapping it again
    # ([text_list]) produced a batch whose only item was a list holding the
    # whole sentence as a single element instead of its characters.
    final_text_list = text_list
print(f"text : {text_list}")
print(f"pinyin: {final_text_list}")
# Duration: number of frames to generate, one per hop_length samples of the
# conditioning audio.
duration = audio.shape[-1] // hop_length
ref_audio_len = 0  # leading frames to strip from the result afterwards

# Inference: sample a mel spectrogram; edit_mask tells the model which frames
# of `audio` to keep as conditioning.
with torch.inference_mode():
    generated, trajectory = model.sample(
        cond=audio,
        text=final_text_list,
        duration=duration,
        steps=nfe_step,
        cfg_strength=cfg_strength,
        sway_sampling_coef=sway_sampling_coef,
        seed=seed,
        edit_mask=edit_mask,
    )
    print(f"Generated mel: {generated.shape}")
# Final result: drop ref_audio_len leading frames, reorder to (1, mel, frames)
# for the vocoder, decode on the CPU, undo the input loudness boost, and save
# both the spectrogram image and the waveform.
generated = generated[:, ref_audio_len:, :]
generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
generated_wave = vocos.decode(generated_mel_spec.cpu())
if rms < target_rms:
    # Scale back down by the same factor the input was boosted by.
    generated_wave = generated_wave * rms / target_rms
spectrogram_png = f"{output_dir}/test_single_edit.png"
edited_wav = f"{output_dir}/test_single_edit.wav"
save_spectrogram(generated_mel_spec[0].cpu().numpy(), spectrogram_png)
torchaudio.save(edited_wav, generated_wave, target_sample_rate)
print(f"Generated wav: {generated_wave.shape}")