tortoise5x / tortoise /inference.py
djkesu's picture
Added all documents for inference
3e423f6
raw
history blame
No virus
5.54 kB
import os
import sys
from random import randint
from typing import List, Optional, Set, Union
from tortoise.utils.audio import get_voices, load_audio, load_voices
from tortoise.utils.text import split_and_recombine_text
def get_all_voices(extra_voice_dirs_str: str = ""):
    """Look up the available voices and parse the extra voice directories.

    Args:
        extra_voice_dirs_str: Comma-separated extra voice directories. An
            empty string means "no extra directories".

    Returns:
        A ``(voices, extra_voice_dirs)`` tuple: ``voices`` is the sorted list
        built from iterating ``get_voices(extra_voice_dirs)`` and
        ``extra_voice_dirs`` is the input split on commas (entries are not
        stripped or otherwise normalized).
    """
    extra_voice_dirs = []
    if extra_voice_dirs_str:
        extra_voice_dirs = extra_voice_dirs_str.split(",")
    available = sorted(get_voices(extra_voice_dirs))
    return available, extra_voice_dirs
def parse_voice_str(voice_str: str, all_voices: List[str]):
    """Parse a voice selection string into groups of voice names.

    Groups are separated by ``,`` and the voices inside one group by ``&``.
    The literal string ``"all"`` selects every entry of ``all_voices``, each
    as its own single-voice group.

    Args:
        voice_str: The selection string, e.g. ``"a,b&c"`` or ``"all"``.
        all_voices: Names that are accepted as valid voices.

    Returns:
        A list of groups, each group being a list of voice names.

    Raises:
        ValueError: If a name is neither ``"random"`` nor in ``all_voices``.
    """
    if voice_str == "all":
        requested = all_voices
    else:
        requested = voice_str.split(",")

    # str.split already yields a one-element list when "&" is absent.
    groups = [entry.split("&") for entry in requested]

    for group in groups:
        for name in group:
            if name == "random" or name in all_voices:
                continue
            raise ValueError(
                f"voice {name} not available, use --list-voices to see available voices."
            )
    return groups
def voice_loader(selected_voices: list, extra_voice_dirs: List[str]):
    """Lazily load each selected voice group.

    Args:
        selected_voices: Iterable of voice groups (as produced by
            ``parse_voice_str``).
        extra_voice_dirs: Extra directories forwarded to ``load_voices``.

    Yields:
        One tuple per group: the group itself followed by the unpacked
        result of ``load_voices(group, extra_voice_dirs)``.
    """
    for voice_group in selected_voices:
        loaded = load_voices(voice_group, extra_voice_dirs)
        yield (voice_group, *loaded)
def parse_multiarg_text(text: List[str]):
    """Build the input text from CLI words, falling back to standard input.

    Args:
        text: Words to join with single spaces. When empty (or ``None``),
            everything readable from ``sys.stdin`` is used instead.

    Returns:
        The resulting text with leading and trailing whitespace removed.
    """
    if text:
        raw = " ".join(text)
    else:
        # Iterating stdin yields its lines; joining them gives the full input.
        raw = "".join(sys.stdin)
    return raw.strip()
def split_text(text: str, text_split: str):
    """Split ``text`` into chunks via ``split_and_recombine_text``.

    Args:
        text: The text to split.
        text_split: Either empty (use the splitter's defaults) or a string of
            the form ``"desired_length,max_length"`` with two integers.

    Returns:
        The non-empty result of ``split_and_recombine_text``.

    Raises:
        ValueError: If ``text_split`` does not parse as two integers, if
            ``desired_length`` exceeds ``max_length``, or if splitting
            produces nothing.
    """
    if not text_split:
        texts = split_and_recombine_text(text)
    else:
        limits = map(int, text_split.split(","))
        desired_length, max_length = limits
        if desired_length > max_length:
            raise ValueError(
                f"--text-split: desired_length ({desired_length}) must be <= max_length ({max_length})"
            )
        texts = split_and_recombine_text(text, desired_length, max_length)

    if texts:
        return texts
    raise ValueError("no text provided")
def validate_output_dir(output_dir: str, selected_voices: list, candidates: int):
    """Create the output directory, or check that none is needed.

    With an ``output_dir`` the directory (and any missing parents) is
    created. Without one, only a single voice group and a single candidate
    are allowed.

    Args:
        output_dir: Target directory; falsy means "no output directory".
        selected_voices: The selected voice groups.
        candidates: Number of candidates to generate per voice.

    Returns:
        ``output_dir`` unchanged.

    Raises:
        ValueError: If no output directory is given while more than one
            voice group or more than one candidate was requested.
    """
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        return output_dir

    # The original messages ended with a stray '"' character; removed here.
    if len(selected_voices) > 1:
        raise ValueError("cannot have multiple voices without --output-dir")
    if candidates > 1:
        raise ValueError("cannot have multiple candidates without --output-dir")
    return output_dir
def check_pydub(play: bool):
    """Import ``pydub`` on demand when playback was requested.

    Args:
        play: Whether playback (``--play``) was requested.

    Returns:
        The imported ``pydub`` module when ``play`` is truthy, else ``None``.

    Raises:
        RuntimeError: If ``play`` is truthy but ``pydub`` (or its
            ``playback`` submodule) cannot be imported.
    """
    if not play:
        return None
    try:
        import pydub
        import pydub.playback
    except ImportError:
        raise RuntimeError(
            '--play requires pydub to be installed, which can be done with "pip install pydub"'
        )
    return pydub
def get_seed(seed: Optional[int]):
    """Return ``seed``, or a random 32-bit value when no seed was given.

    Args:
        seed: An explicit seed, or ``None`` to draw one.

    Returns:
        ``seed`` itself when it is not ``None`` (including ``0``); otherwise
        a random integer in ``[0, 2**32 - 1]``.
    """
    if seed is not None:
        return seed
    return randint(0, 2**32 - 1)
from pathlib import Path
from typing import Any, Callable
import torch
import torchaudio
def run_and_save_tts(
    call_tts,
    text,
    output_dir: Path,
    return_deterministic_state,
    return_filepaths=False,
    voicefixer=True,
):
    """Run TTS on one text and save every generated candidate as a WAV file.

    Args:
        call_tts: Callable invoked as ``call_tts(text)``. When
            ``return_deterministic_state`` is truthy its result is unpacked
            as ``(gen, dbg)``; otherwise the result is ``gen`` itself.
        text: The text passed to ``call_tts``.
        output_dir: Directory for the output files; created if missing (its
            parent must already exist).
        return_deterministic_state: Whether ``call_tts`` also returns a
            debug state, which is then saved to ``output_dir / "dbg.pt"``.
        return_filepaths: Return the written file paths instead of the clips.
        voicefixer: Forwarded to ``save_gen_with_voicefix``.

    Returns:
        The list of ``<index>.wav`` paths when ``return_filepaths`` is
        truthy, otherwise the list of clips after ``squeeze(0).cpu()``.
    """
    output_dir.mkdir(exist_ok=True)

    outcome = call_tts(text)
    if return_deterministic_state:
        gen, dbg = outcome
        torch.save(dbg, output_dir / "dbg.pt")
    else:
        gen = outcome

    # A single result is handled like a one-element list of candidates.
    candidates = gen if isinstance(gen, list) else [gen]
    candidates = [clip.squeeze(0).cpu() for clip in candidates]

    fps = [output_dir / f"{idx}.wav" for idx in range(len(candidates))]
    for clip, wav_path in zip(candidates, fps):
        save_gen_with_voicefix(clip, wav_path, squeeze=False, voicefixer=voicefixer)
    return fps if return_filepaths else candidates
def _fragment_sort_key(path: Path):
    """Order fragment files numerically by stem, with non-numeric stems last."""
    stem = path.stem
    if stem.isdecimal():
        return (0, int(stem), "")
    return (1, 0, stem)


def infer_on_texts(
    call_tts: Callable[[str], Any],
    texts: List[str],
    output_dir: Union[str, Path],
    return_deterministic_state: bool,
    lines_to_regen: Set[int],
    logger=print,
    return_filepaths=False,
    voicefixer=True,
):
    """Generate (or reuse) audio for each text and write combined outputs.

    Every text gets its own sub-directory ``output_dir/<index>``. Unless its
    index is in ``lines_to_regen``, existing ``*.wav`` files in that
    directory are loaded instead of regenerating. Candidate ``i`` of every
    text is then concatenated along the last dimension and written to
    ``output_dir/combined-<i>.wav``.

    Args:
        call_tts: Callable forwarded to ``run_and_save_tts``.
        texts: The text chunks, in playback order.
        output_dir: Base directory; created with any missing parents.
        return_deterministic_state: Forwarded to ``run_and_save_tts``.
        lines_to_regen: Indices of texts that must be regenerated even if
            audio fragments already exist on disk.
        logger: Callable receiving progress messages.
        return_filepaths: Return the combined file paths instead of tensors.
        voicefixer: Forwarded to ``run_and_save_tts`` for the per-text
            fragments; the combined files are never voice-fixed.

    Returns:
        The ``combined-<i>.wav`` paths when ``return_filepaths`` is truthy,
        otherwise the concatenated tensors. Empty when ``texts`` is empty.
    """
    audio_chunks = []
    base_p = Path(output_dir)
    base_p.mkdir(parents=True, exist_ok=True)
    for text_idx, text in enumerate(texts):
        line_p = base_p / f"{text_idx}"
        line_p.mkdir(exist_ok=True)

        if text_idx not in lines_to_regen:
            # glob order is unspecified; sort so that cached candidate i
            # lines up with candidate i of every other text when combining.
            files = sorted(line_p.glob("*.wav"), key=_fragment_sort_key)
            if files:
                logger(f"loading existing audio fragments for [{text_idx}]")
                audio_chunks.append([load_audio(str(f), 24000) for f in files])
                continue
            logger(f"no existing audio fragment for [{text_idx}]")

        logger(f"generating audio for text {text_idx}: {text}")
        audio_chunks.append(
            run_and_save_tts(
                call_tts,
                text,
                line_p,
                return_deterministic_state,
                voicefixer=voicefixer,
            )
        )

    # Nothing to combine (empty ``texts``): avoid indexing audio_chunks[0].
    if not audio_chunks:
        return []

    fnames = []
    results = []
    # The candidate count is taken from the first text; every other text is
    # expected to provide at least that many candidates.
    for i in range(len(audio_chunks[0])):
        resultant = torch.cat([c[i] for c in audio_chunks], dim=-1)
        fnames.append(base_p / f"combined-{i}.wav")
        save_gen_with_voicefix(
            resultant, fnames[-1], squeeze=False, voicefixer=False
        )  # do not run fix on combined!!
        results.append(resultant)
    return fnames if return_filepaths else results
from voicefixer import VoiceFixer
# Shared VoiceFixer instance, constructed once at import time and reused by
# save_gen_with_voicefix for every restore call.
vfixer = VoiceFixer()
def save_gen_with_voicefix(g, fpath, squeeze=True, voicefixer=True):
    """Save a generated clip as WAV and optionally run VoiceFixer on it in place.

    Args:
        g: The audio tensor to save.
        fpath: Destination file path; also used as both input and output of
            the VoiceFixer restore step.
        squeeze: When truthy, ``g.squeeze(0).cpu()`` is saved instead of
            ``g`` itself.
        voicefixer: When truthy, run ``vfixer.restore`` on the saved file,
            overwriting it.
    """
    audio = g.squeeze(0).cpu() if squeeze else g
    torchaudio.save(fpath, audio, 24000, format="wav")
    if not voicefixer:
        return
    vfixer.restore(
        input=fpath,
        output=fpath,
        # Request CUDA only when a device is actually present, so the restore
        # step also works on CPU-only hosts instead of hard-coding cuda=True.
        cuda=torch.cuda.is_available(),
        mode=0,
        # your_vocoder_func = convert_mel_to_wav # TODO test if integration with unvinet improves things
    )