File size: 3,412 Bytes
ad16788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#!/usr/bin/env python3
import argparse
from collections import Counter
from itertools import zip_longest
import logging
from pathlib import Path
import sys
from typing import List
from typing import Optional

from espnet.utils.cli_utils import get_commandline_args


def split_scps(
    scps: List[str],
    num_splits: int,
    names: Optional[List[str]],
    output_dir: str,
    log_level: str,
):
    """Split parallel scp/text files into ``num_splits`` round-robin pieces.

    The input files are read in lockstep. Line ``i`` of each input is appended
    to ``<output_dir>/<name>/split.<i % num_splits>``. Afterwards the split
    count is written to ``<output_dir>/<name>/num_splits`` for every name.

    Args:
        scps: Input files. Every file must have the same number of lines, and
            on each line every file must have the same first token (key).
        num_splits: Number of pieces; must be at least 2.
        names: Output sub-directory name for each input, in the same order as
            ``scps``. When ``None``, the basename of each input file is used.
        output_dir: Root directory under which per-name directories are made.
        log_level: Level passed to ``logging.basicConfig``.

    Raises:
        RuntimeError: If ``num_splits`` < 2; if ``names`` are duplicated or
            their count differs from ``scps``; if the inputs differ in line
            count or in the key of some line; or if there are fewer lines
            than ``num_splits``.
    """
    from contextlib import ExitStack

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if num_splits < 2:
        raise RuntimeError(f"{num_splits} must be more than 1")

    if names is None:
        names = [Path(s).name for s in scps]
    # Each input needs exactly one, distinct output directory; otherwise
    # zip() below would silently drop inputs or mix two inputs in one file.
    if len(names) != len(scps):
        raise RuntimeError(
            f"The number of names must match the number of scps: "
            f"{len(names)} != {len(scps)}"
        )
    if len(set(names)) != len(names):
        raise RuntimeError(f"names are duplicated: {names}")

    out_root = Path(output_dir)
    for name in names:
        (out_root / name).mkdir(parents=True, exist_ok=True)

    counter = Counter()
    linenum = -1
    # ExitStack guarantees every opened input is closed, even on error.
    with ExitStack() as stack:
        scp_files = [
            stack.enter_context(open(s, "r", encoding="utf-8")) for s in scps
        ]
        # Remove existing files: pieces are opened in append mode below, so
        # leftovers from a previous run would otherwise be extended.
        for n in range(num_splits):
            for name in names:
                split_path = out_root / name / f"split.{n}"
                if split_path.exists():
                    split_path.unlink()

        for linenum, lines in enumerate(zip_longest(*scp_files)):
            # zip_longest pads exhausted files with None.
            if any(line is None for line in lines):
                raise RuntimeError("Number of lines are mismatched")

            # All inputs must carry the same key on this line.
            prev_key = None
            for line in lines:
                key = line.rstrip().split(maxsplit=1)[0]
                if prev_key is not None and prev_key != key:
                    raise RuntimeError("Not sorted or not having same keys")
                prev_key = key

            # Select a piece from split texts alternatively
            num = linenum % num_splits
            counter[num] += 1
            # Write lines respectively
            for line, name in zip(lines, names):
                # To reduce the number of opened file descriptors, open now
                with (out_root / name / f"split.{num}").open(
                    "a", encoding="utf-8"
                ) as f:
                    f.write(line)

    if linenum + 1 < num_splits:
        raise RuntimeError(
            f"The number of lines is less than num_splits: {linenum + 1} < {num_splits}"
        )

    for name in names:
        with (out_root / name / "num_splits").open("w", encoding="utf-8") as f:
            f.write(str(num_splits))
    logging.info(f"N lines of split text: {set(counter.values())}")


def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser whose options map onto ``split_scps``.

    Every destination name (``log_level``, ``scps``, ``names``,
    ``num_splits``, ``output_dir``) matches a ``split_scps`` keyword argument,
    so ``vars(parser.parse_args())`` can be splatted straight into it.
    """
    parser = argparse.ArgumentParser(
        description="Split scp files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--log_level",
        # Upper-case before the choices check so e.g. "info" is accepted.
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument("--scps", required=True, help="Input texts", nargs="+")
    parser.add_argument("--names", help="Output names for each files", nargs="+")
    # Required: split_scps cannot work without it (a missing value would
    # surface as a TypeError on ``None < 2`` instead of a usage error).
    parser.add_argument(
        "--num_splits", required=True, help="Split number", type=int
    )
    parser.add_argument("--output_dir", required=True, help="Output directory")
    return parser


def main(cmd=None):
    """Run the scp splitter from the command line.

    Args:
        cmd: Optional argument list for the parser. When ``None``, argparse
            falls back to ``sys.argv``.
    """
    # Echo the invocation (as reported by get_commandline_args) to stderr.
    print(get_commandline_args(), file=sys.stderr)
    namespace = get_parser().parse_args(cmd)
    split_scps(**vars(namespace))


if __name__ == "__main__":
    main()