import itertools
import os
import re
from collections import defaultdict
from multiprocessing import Pool
from pathlib import Path
from typing import Sequence, Union

import click
import numpy as np
from loguru import logger
from tqdm import tqdm

from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
from fish_speech.utils.file import load_filelist

# Limit MKL/OpenMP to one thread each so that many worker processes
# running in parallel do not oversubscribe the CPU.
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"


def task_generator_folder(root: Path, text_extension: Union[str, Sequence[str]]):
    """Yield one task per directory under `root`, pairing each .npy file with its transcript(s)."""
    files = sorted(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))

    grouped_files = defaultdict(list)
    for file in tqdm(files, desc=f"Grouping {root}"):
        p = str(file.parent)

        try:
            if isinstance(text_extension, str):
                texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")]
            else:
                # click passes a tuple when the option is given with multiple=True
                texts = [
                    file.with_suffix(ext).read_text(encoding="utf-8")
                    for ext in text_extension
                ]
        except Exception as e:
            logger.error(f"Failed to read text {file}: {e}")
            continue

        grouped_files[p].append((file, texts))

    logger.info(
        f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
    )

    for name, subset in grouped_files.items():
        yield name, subset, "folder"


def task_generator_filelist(filelist):
    """Yield one task per speaker listed in a filelist."""
    grouped_files = defaultdict(list)
    for filename, speaker, _, text in load_filelist(filelist):
        grouped_files[speaker].append((Path(filename), [text]))

    logger.info(f"Found {len(grouped_files)} groups in {filelist}")
    for speaker, values in grouped_files.items():
        yield speaker, values, "filelist"


def run_task(task):
    """Convert one group of (path, transcripts) pairs into a packed TextData message."""
    name, subset, source = task

    # Parse the files
    sentences = []
    for file, texts in subset:
        np_file = file.with_suffix(".npy")
        if not np_file.exists():
            logger.warning(f"Can't find {np_file}")
            continue

        new_texts = []
        for text in texts:
            # Simple cleaning: replace {...} and <...> spans with a space,
            # then collapse runs of whitespace into a single space
            text = re.sub(r"\{.*?\}", " ", text)
            text = re.sub(r"<.*?>", " ", text)
            text = re.sub(r"\s+", " ", text)
            new_texts.append(text)

        try:
            semantics = np.load(np_file)
        except Exception as e:
            logger.error(f"Failed to load {np_file}: {e}")
            continue

        if isinstance(semantics, np.ndarray):
            semantics = semantics.tolist()

        sentences.append(
            Sentence(
                texts=new_texts,
                semantics=[Semantics(values=s) for s in semantics],
            )
        )

    # Pack the sentences
    return pack_pb_stream(
        TextData(
            source=source,
            name=name,
            sentences=sentences,
        )
    )


@click.command()
@click.option(
    "--input",
    type=click.Path(path_type=Path),
    required=True,
    help="A folder containing the dataset or a filelist",
    multiple=True,
)
@click.option(
    "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
)
@click.option("--num-workers", type=int, default=16)
@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
@click.option(
    "--shard-size",
    type=int,
    default=10,
    help="Approximate maximum size of each shard in MB; a shard is closed once it exceeds this size",
)
def main(input, output, num_workers, text_extension, shard_size):
    generator_fns = []

    for f in input:
        assert f.exists(), f"{f} not found"

        if f.is_dir():
            generator_fn = task_generator_folder(f, text_extension)
        else:
            generator_fn = task_generator_filelist(f)

        generator_fns.append(generator_fn)

    generator_fn = itertools.chain(*generator_fns)
    output.mkdir(parents=True, exist_ok=True)

    dataset_fp = None
    tar_idx = 0  # index of the shard currently being written
    written_size = 0

    with Pool(num_workers) as p:
        for result in tqdm(p.imap_unordered(run_task, generator_fn)):
            if dataset_fp is None:
                dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")

            dataset_fp.write(result)
            written_size += len(result)

            if written_size > shard_size * 1024 * 1024:
                dataset_fp.close()
                dataset_fp = None
                written_size = 0
                tar_idx += 1
                logger.info(f"Finished writing {tar_idx} shards to {output}")

    # Close the last, partially filled shard (if any)
    if dataset_fp is not None:
        dataset_fp.close()
        tar_idx += 1

    # tar_idx now equals the number of shards written
    logger.info(f"Finished writing {tar_idx} shards to {output}")


if __name__ == "__main__":
    main()
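
# Example invocation (a sketch only: the script filename and the input paths
# below are hypothetical placeholders, not files shipped with this code).
# Each --input may be a directory of .npy files with matching transcripts, or a
# filelist readable by load_filelist; --input and --text-extension can both be
# repeated.
#
#   python build_dataset.py \
#       --input data/speaker-a \
#       --input data/train-filelist.txt \
#       --output data/quantized-dataset-ft \
#       --text-extension .txt \
#       --num-workers 8 \
#       --shard-size 10
#
# With these settings each output shard (00000000.protos, 00000001.protos, ...)
# is closed as soon as it grows past roughly 10 MB.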