File size: 3,412 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
#!/usr/bin/env python3
import argparse
from collections import Counter
from contextlib import ExitStack
from itertools import zip_longest
import logging
from pathlib import Path
import sys
from typing import List
from typing import Optional

from espnet.utils.cli_utils import get_commandline_args
def split_scps(
    scps: List[str],
    num_splits: int,
    names: Optional[List[str]],
    output_dir: str,
    log_level: str,
):
    """Split several line-aligned scp files into ``num_splits`` pieces each.

    Line ``i`` of every input is appended to
    ``<output_dir>/<name>/split.<i % num_splits>``, so the pieces are
    interleaved round-robin rather than contiguous. All inputs must have the
    same number of lines, and line ``i`` must carry the same key (first
    whitespace-separated token) in every input. After splitting, the split
    count is written to ``<output_dir>/<name>/num_splits``.

    Args:
        scps: Paths of the input scp files.
        num_splits: Number of pieces; must be at least 2.
        names: Output sub-directory name for each input, in the same order as
            ``scps``. Defaults to the base name of each input path.
        output_dir: Root directory for the split outputs.
        log_level: Level passed to ``logging.basicConfig``.

    Raises:
        RuntimeError: If ``num_splits`` < 2, ``names`` are duplicated or their
            count differs from ``scps``, the inputs differ in line count or in
            keys, a line has no key, or the inputs have fewer lines than
            ``num_splits``.
    """
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if num_splits < 2:
        raise RuntimeError(f"{num_splits} must be more than 1")

    if names is None:
        names = [Path(s).name for s in scps]
    if len(names) != len(scps):
        # zip() below would otherwise silently drop the surplus inputs/names.
        raise RuntimeError(
            "The number of names must match the number of scps: "
            f"{len(names)} != {len(scps)}"
        )
    if len(set(names)) != len(names):
        raise RuntimeError(f"names are duplicated: {names}")

    out_dirs = [Path(output_dir) / name for name in names]
    for out_dir in out_dirs:
        out_dir.mkdir(parents=True, exist_ok=True)

    counter = Counter()
    linenum = -1
    # ExitStack guarantees every opened input is closed, including when a
    # later open() or a validation error raises part-way through.
    with ExitStack() as stack:
        scp_files = [
            stack.enter_context(open(s, "r", encoding="utf-8")) for s in scps
        ]

        # Remove existing files: the outputs below are opened in append mode.
        for n in range(num_splits):
            for out_dir in out_dirs:
                split_path = out_dir / f"split.{n}"
                if split_path.exists():
                    split_path.unlink()

        for linenum, lines in enumerate(zip_longest(*scp_files)):
            # zip_longest pads the shorter inputs with None.
            if any(line is None for line in lines):
                raise RuntimeError("Number of lines are mismatched")

            prev_key = None
            for line in lines:
                fields = line.split(maxsplit=1)
                if not fields:
                    raise RuntimeError(f"Line {linenum + 1} has no key")
                key = fields[0]
                if prev_key is not None and prev_key != key:
                    raise RuntimeError("Not sorted or not having same keys")
                prev_key = key

            # Select a piece from split texts alternately (round-robin).
            num = linenum % num_splits
            counter[num] += 1
            for line, out_dir in zip(lines, out_dirs):
                # Open per line instead of keeping every split open, to limit
                # the number of simultaneously open file descriptors.
                with (out_dir / f"split.{num}").open("a", encoding="utf-8") as f:
                    f.write(line)

    if linenum + 1 < num_splits:
        raise RuntimeError(
            f"The number of lines is less than num_splits: {linenum + 1} < {num_splits}"
        )
    for out_dir in out_dirs:
        with (out_dir / "num_splits").open("w", encoding="utf-8") as f:
            f.write(str(num_splits))
    logging.info(f"N lines of split text: {set(counter.values())}")
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    The option destinations (``log_level``, ``scps``, ``names``,
    ``num_splits``, ``output_dir``) match the keyword parameters of
    ``split_scps``, so ``vars(args)`` can be passed to it directly.
    """
    parser = argparse.ArgumentParser(
        description="Split scp files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--log_level",
        # Accept lower-case level names by normalizing before the choices check.
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument("--scps", required=True, help="Input texts", nargs="+")
    parser.add_argument("--names", help="Output names for each files", nargs="+")
    # Required: split_scps() compares num_splits with an int, so a missing
    # value (None) would otherwise surface as a TypeError, not a usage error.
    parser.add_argument(
        "--num_splits", required=True, help="Split number", type=int
    )
    parser.add_argument("--output_dir", required=True, help="Output directory")
    return parser
def main(cmd=None):
    """Run the scp splitter from the command line.

    Echoes the invoking command line to stderr, parses ``cmd`` (``None``
    means argparse falls back to ``sys.argv[1:]``), and forwards every parsed
    option to ``split_scps`` as keyword arguments.
    """
    print(get_commandline_args(), file=sys.stderr)
    options = get_parser().parse_args(cmd)
    split_scps(**vars(options))


if __name__ == "__main__":
    main()
|