import os
import sys
from random import randint
from typing import List, Optional, Set, Union

from tortoise.utils.audio import get_voices, load_audio, load_voices
from tortoise.utils.text import split_and_recombine_text


def get_all_voices(extra_voice_dirs_str: str = ""):
    """Return the sorted list of available voice names plus the parsed extra voice dirs."""
    extra_voice_dirs = extra_voice_dirs_str.split(",") if extra_voice_dirs_str else []
    return sorted(get_voices(extra_voice_dirs)), extra_voice_dirs


def parse_voice_str(voice_str: str, all_voices: List[str]):
    """Parse a comma-separated voice string ("&" combines voices) and validate each name."""
    selected_voices = all_voices if voice_str == "all" else voice_str.split(",")
    selected_voices = [v.split("&") if "&" in v else [v] for v in selected_voices]
    for voices in selected_voices:
        for v in voices:
            if v != "random" and v not in all_voices:
                raise ValueError(
                    f"voice {v} not available, use --list-voices to see available voices."
                )

    return selected_voices


def voice_loader(selected_voices: list, extra_voice_dirs: List[str]):
    """Yield each voice group together with the values load_voices returns for it."""
    for voices in selected_voices:
        yield voices, *load_voices(voices, extra_voice_dirs)


def parse_multiarg_text(text: List[str]):
    """Join positional text arguments, falling back to stdin when none are given."""
    return (" ".join(text) if text else "".join(line for line in sys.stdin)).strip()


def split_text(text: str, text_split: str):
    """Split text into chunks; text_split is an optional "desired_length,max_length" pair."""
    if text_split:
        desired_length, max_length = map(int, text_split.split(","))
        if desired_length > max_length:
            raise ValueError(
                f"--text-split: desired_length ({desired_length}) must be <= max_length ({max_length})"
            )
        texts = split_and_recombine_text(text, desired_length, max_length)
    else:
        texts = split_and_recombine_text(text)

    if not texts:
        raise ValueError("no text provided")
    return texts


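# Usage sketch for split_text (the values below are illustrative, not defaults):
# text_split="200,300" asks split_and_recombine_text for roughly 200-character
# chunks capped at 300 characters; an empty text_split uses the library defaults.
#
#     texts = split_text(parse_multiarg_text(["Some long passage ..."]), "200,300")

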
def validate_output_dir(output_dir: str, selected_voices: list, candidates: int):
    """Create the output directory if given; otherwise reject multi-voice or multi-candidate runs."""
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    else:
        if len(selected_voices) > 1:
            raise ValueError("cannot have multiple voices without --output-dir")
        if candidates > 1:
            raise ValueError("cannot have multiple candidates without --output-dir")
    return output_dir


def check_pydub(play: bool):
    """If playback was requested, import and return pydub, raising a helpful error when it is missing."""
    if play:
        try:
            import pydub
            import pydub.playback

            return pydub
        except ImportError:
            raise RuntimeError(
                '--play requires pydub to be installed, which can be done with "pip install pydub"'
            )


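# Hedged example of how a caller might use the module returned by check_pydub;
# AudioSegment.from_wav and pydub.playback.play are standard pydub APIs, but the
# surrounding file path and variable names are assumptions:
#
#     pydub = check_pydub(play=True)
#     if pydub:
#         pydub.playback.play(pydub.AudioSegment.from_wav("results/combined-0.wav"))

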
def get_seed(seed: Optional[int]):
    """Return the given seed, or a random 32-bit seed when none is provided."""
    return randint(0, 2**32 - 1) if seed is None else seed


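# Sketch of how these helpers are typically chained by a CLI front end; the
# `args.*` attribute names are assumptions about the caller, not part of this module:
#
#     all_voices, extra_voice_dirs = get_all_voices(args.extra_voice_dirs)
#     selected_voices = parse_voice_str(args.voice, all_voices)
#     texts = split_text(parse_multiarg_text(args.text), args.text_split)
#     output_dir = validate_output_dir(args.output_dir, selected_voices, args.candidates)
#     seed = get_seed(args.seed)

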
from pathlib import Path
from typing import Any, Callable

import torch
import torchaudio


def run_and_save_tts(
    call_tts,
    text,
    output_dir: Path,
    return_deterministic_state,
    return_filepaths=False,
    voicefixer=True,
):
    """Run call_tts on text, save each candidate as <i>.wav in output_dir, and return the clips or their paths."""
    output_dir.mkdir(exist_ok=True)
    if return_deterministic_state:
        gen, dbg = call_tts(text)
        torch.save(dbg, output_dir / "dbg.pt")
    else:
        gen = call_tts(text)

    if not isinstance(gen, list):
        gen = [gen]
    gen = [g.squeeze(0).cpu() for g in gen]
    fps = []
    for i, g in enumerate(gen):
        fps.append(output_dir / f"{i}.wav")
        save_gen_with_voicefix(g, fps[-1], squeeze=False, voicefixer=voicefixer)

    return fps if return_filepaths else gen


def infer_on_texts(
    call_tts: Callable[[str], Any],
    texts: List[str],
    output_dir: Union[str, Path],
    return_deterministic_state: bool,
    lines_to_regen: Set[int],
    logger=print,
    return_filepaths=False,
    voicefixer=True,
):
    """Generate (or reuse cached) audio for each text chunk, then concatenate the chunks per candidate."""
    audio_chunks = []
    base_p = Path(output_dir)
    base_p.mkdir(exist_ok=True)

    for text_idx, text in enumerate(texts):
        line_p = base_p / f"{text_idx}"
        line_p.mkdir(exist_ok=True)

        if text_idx not in lines_to_regen:
            # Sort so candidate i consistently maps to "<i>.wav" regardless of filesystem order.
            files = sorted(line_p.glob("*.wav"))
            if files:
                logger(f"loading existing audio fragments for [{text_idx}]")
                audio_chunks.append([load_audio(str(f), 24000) for f in files])
                continue
            else:
                logger(f"no existing audio fragment for [{text_idx}]")

        logger(f"generating audio for text {text_idx}: {text}")
        audio_chunks.append(
            run_and_save_tts(
                call_tts,
                text,
                line_p,
                return_deterministic_state,
                voicefixer=voicefixer,
            )
        )

    fnames = []
    results = []
    for i in range(len(audio_chunks[0])):
        resultant = torch.cat([c[i] for c in audio_chunks], dim=-1)
        fnames.append(base_p / f"combined-{i}.wav")
        # Fragments were already voicefixed when saved, so skip voicefixer for the combined file.
        save_gen_with_voicefix(resultant, fnames[-1], squeeze=False, voicefixer=False)
        results.append(resultant)

    return fnames if return_filepaths else results


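# Hedged usage sketch for infer_on_texts; `tts.tts_with_preset` and its keyword
# arguments are assumptions about the calling script, not part of this module:
#
#     infer_on_texts(
#         call_tts=lambda t: tts.tts_with_preset(t, voice_samples=voice_samples, k=candidates),
#         texts=texts,
#         output_dir="results/",
#         return_deterministic_state=False,
#         lines_to_regen=set(range(len(texts))),
#     )

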
from voicefixer import VoiceFixer

vfixer = VoiceFixer()


def save_gen_with_voicefix(g, fpath, squeeze=True, voicefixer=True):
    """Save a generated clip as a 24 kHz WAV and optionally run VoiceFixer restoration in place."""
    torchaudio.save(fpath, g.squeeze(0).cpu() if squeeze else g, 24000, format="wav")
    if voicefixer:
        vfixer.restore(
            input=fpath,
            output=fpath,
            cuda=True,
            mode=0,
        )