fish-speech-1 / tools /vqgan /create_train_split.py
PoTaTo721's picture
update to 1.2
69e8a46
raw
history blame
3.01 kB
import math
from pathlib import Path
from random import Random
import click
from loguru import logger
from pydub import AudioSegment
from tqdm import tqdm
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
def _validate_options(val_ratio, val_count, min_duration, max_duration):
    """Return an error message for an invalid option combination, or None.

    Only checks that do not depend on the file list live here, so they can run
    before any audio is decoded (duration filtering decodes every file).
    """
    if val_count is not None and val_ratio is not None:
        return "Cannot specify both val_count and val_ratio"
    # A non-positive ratio would give val_size <= 0; a negative size makes the
    # slices below silently swap train/val, so reject it outright.
    if val_ratio is not None and not 0 < val_ratio <= 1:
        return "val_ratio must be in the range (0, 1]"
    if val_count is not None and val_count < 1:
        return "val_count must be between 1 and number of files"
    if (
        min_duration is not None
        and max_duration is not None
        and min_duration > max_duration
    ):
        return "min_duration must not be greater than max_duration"
    return None


def _filter_by_duration(files, root, min_duration, max_duration):
    """Keep files whose decoded duration (seconds) lies within the given bounds.

    Either bound may be None to leave that side open. Returns the kept paths as
    strings relative to ``root``. A file that fails to decode (or fails
    ``relative_to``) is logged and skipped rather than aborting the run.
    """
    kept = []
    for file in tqdm(files):
        try:
            # len(AudioSegment) is in milliseconds.
            duration = len(AudioSegment.from_file(str(file))) / 1000.0
            if min_duration is not None and duration < min_duration:
                logger.info(
                    f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}"
                )
                continue
            if max_duration is not None and duration > max_duration:
                logger.info(
                    f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}"
                )
                continue
            kept.append(str(file.relative_to(root)))
        except Exception as e:
            # Best effort: one unreadable file should not abort the whole split.
            logger.warning(f"Error processing {file}: {e}")
    return kept


def _compute_val_size(num_files, val_ratio, val_count):
    """Number of files to put in the validation split.

    Assumes the options already passed ``_validate_options`` and that
    ``val_count`` (if given) does not exceed ``num_files``.
    """
    if val_count is not None:
        return val_count
    if val_ratio is not None:
        return math.ceil(num_files * val_ratio)
    logger.info("Validation ratio and count not specified, using min(20%, 100)")
    return min(100, math.ceil(num_files * 0.2))


@click.command()
@click.argument("root", type=click.Path(exists=True, path_type=Path))
@click.option("--val-ratio", type=float, default=None)
@click.option("--val-count", type=int, default=None)
@click.option("--filelist", default=None, type=Path)
@click.option("--min-duration", default=None, type=float)
@click.option("--max-duration", default=None, type=float)
def main(root, val_ratio, val_count, filelist, min_duration, max_duration):
    """Split audio files under ROOT into VQ train/validation filelists.

    Files come from --filelist (first column) or from a recursive scan of ROOT.
    They are optionally filtered by duration in seconds, shuffled with a fixed
    seed, and written as ROOT-relative paths to vq_train_filelist.txt and
    vq_val_filelist.txt inside ROOT. Use at most one of --val-ratio (0, 1] and
    --val-count; with neither, min(20% of files, 100) go to validation.
    """
    error = _validate_options(val_ratio, val_count, min_duration, max_duration)
    if error is not None:
        logger.error(error)
        return

    if filelist:
        files = [i[0] for i in load_filelist(filelist)]
    else:
        files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)

    if min_duration is None and max_duration is None:
        filtered_files = [str(file.relative_to(root)) for file in files]
    else:
        filtered_files = _filter_by_duration(files, root, min_duration, max_duration)

    logger.info(
        f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering"
    )

    # Fixed seed so repeated runs over the same inputs produce the same split.
    Random(42).shuffle(filtered_files)

    # This bound needs the filtered file count, so it cannot be checked earlier.
    if val_count is not None and val_count > len(filtered_files):
        logger.error("val_count must be between 1 and number of files")
        return

    val_size = _compute_val_size(len(filtered_files), val_ratio, val_count)
    logger.info(f"Using {val_size} files for validation")

    # The first val_size shuffled entries form the validation split; the rest train.
    with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(filtered_files[val_size:]))

    with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(filtered_files[:val_size]))

    logger.info("Done")
if __name__ == "__main__":
    # Executed as a script: hand control to the click command defined above,
    # which reads its arguments and options from the command line.
    main()