"""Pack quantized semantic tokens (.npy files) and their transcript texts into
sharded TextData protobuf streams."""

import itertools
import os
import re
from collections import defaultdict
from functools import partial
from multiprocessing import Pool
from pathlib import Path

import click
import numpy as np
from loguru import logger
from tqdm import tqdm

from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
from tools.file import load_filelist

# Keep BLAS libraries single-threaded so the worker pool does not oversubscribe CPU cores.
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"


def task_generator_folder(root: Path, text_extension: str):
    """Walk a dataset folder, pair every .npy token file with its transcript(s),
    and yield one task per sub-folder (the folder name is used as the speaker)."""
    files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
    files = sorted(files)

    grouped_files = defaultdict(list)
    for file in tqdm(files, desc=f"Grouping {root}"):
        p = str(file.parent)
        speaker = file.parent.name

        try:
            if isinstance(text_extension, str):
                texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")]
            else:
                texts = [
                    file.with_suffix(ext).read_text(encoding="utf-8")
                    for ext in text_extension
                ]
        except Exception as e:
            logger.error(f"Failed to read text {file}: {e}")
            continue

        grouped_files[p].append((speaker, file, texts))

    logger.info(
        f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
    )

    for group in grouped_files.values():
        subset = [(f, t) for _, f, t in group]
        yield group[0][0], subset, "folder"


def task_generator_filelist(filelist):
    """Yield one task per speaker from a filelist; each row provides a filename,
    speaker, and transcript text."""
    grouped_files = defaultdict(list)
    for filename, speaker, _, text in load_filelist(filelist):
        grouped_files[speaker].append((Path(filename), [text]))

    logger.info(f"Found {len(grouped_files)} groups in {filelist}")
    for speaker, values in grouped_files.items():
        yield speaker, values, "filelist"


def run_task(task):
    """Load the quantized tokens for one group, clean its transcripts, and return
    the group serialized as a packed TextData protobuf stream."""
    name, subset, source = task

    sentences = []
    for file, texts in subset:
        np_file = file.with_suffix(".npy")
        if not np_file.exists():
            logger.warning(f"Can't find {np_file}")
            continue

        new_texts = []

        for text in texts:
            # Remove curly-brace and angle-bracket annotations, then collapse whitespace.
            text = re.sub(r"\{.*?\}", " ", text)
            text = re.sub(r"<.*?>", " ", text)
            text = re.sub(r"\s+", " ", text)
            new_texts.append(text)

        try:
            semantics = np.load(np_file)
        except Exception as e:
            logger.error(f"Failed to parse {file}: {e}")
            continue

        if isinstance(semantics, np.ndarray):
            semantics = semantics.tolist()

        # Each row of the token array becomes one Semantics message.
        sentences.append(
            Sentence(
                texts=new_texts,
                semantics=[Semantics(values=s) for s in semantics],
            )
        )

    return pack_pb_stream(
        TextData(
            source=source,
            name=name,
            sentences=sentences,
        )
    )


@click.command()
@click.option(
    "--input",
    type=click.Path(path_type=Path),
    required=True,
    help="A folder containing the dataset or a filelist",
    multiple=True,
)
@click.option(
    "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
)
@click.option("--num-workers", type=int, default=16)
@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
@click.option(
    "--shard-size", type=int, default=10, help="The maximum size of each shard in MB"
)
def main(input, output, num_workers, text_extension, shard_size):
    """Build sharded .protos files from one or more dataset folders and/or filelists."""
    generator_fns = []

    for f in input:
        assert f.exists(), f"{f} not found"

        if f.is_dir():
            generator_fn = task_generator_folder(f, text_extension)
        else:
            generator_fn = task_generator_filelist(f)

        generator_fns.append(generator_fn)

    generator_fn = itertools.chain(*generator_fns)
    output.mkdir(parents=True, exist_ok=True)

    dataset_fp = None
    tar_idx = 0
    written_size = 0

    with Pool(num_workers) as p:
        for result in tqdm(p.imap_unordered(run_task, generator_fn)):
            # Lazily open the current shard on first write.
            if dataset_fp is None:
                dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")

            dataset_fp.write(result)
            written_size += len(result)

            # Roll over to a new shard once the size limit is exceeded.
            if written_size > shard_size * 1024 * 1024:
                logger.info(f"Finished writing shard {tar_idx} to {output}")
                dataset_fp.close()
                dataset_fp = None
                written_size = 0
                tar_idx += 1

    if dataset_fp is not None:
        dataset_fp.close()
        tar_idx += 1

    logger.info(f"Finished writing {tar_idx} shards to {output}")


if __name__ == "__main__":
    main()
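
# Example invocation (a sketch; the script filename and data paths are
# illustrative, only the flags above are defined by this file):
#
#   python build_dataset.py \
#       --input data/demo \
#       --output data/quantized-dataset-ft \
#       --num-workers 16 \
#       --text-extension .lab \
#       --shard-size 10
#
# --input may be repeated and may point at either a folder of .npy/transcript
# pairs or a filelist; --text-extension may also be repeated to attach several
# transcript variants to each .npy file.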