File size: 1,694 Bytes
ee0ec3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#!/usr/bin/env python3
import argparse
import os
import random
import sys


def _read_blocks(path):
    """Read one CoNLL-U file and return its blocks as a list of strings.

    A block is every line up to and including an empty line. Lines that
    start with ``"# variant "`` are dropped before the block is assembled,
    and whole blocks containing the text ``"greek.expression"`` are skipped.

    Raises:
        ValueError: if the file ends in the middle of a block, i.e. the last
            block is not terminated by an empty line.
    """
    blocks, pending = [], []
    with open(path, "r", encoding="utf-8") as input_file:
        for line in input_file:
            if line.startswith("# variant "):
                continue
            pending.append(line)
            if not line.rstrip("\n"):
                # An empty line closes the current block.
                block = "".join(pending)
                if "greek.expression" not in block:
                    blocks.append(block)
                pending = []

    # Explicit check instead of `assert`, which is stripped under `python -O`
    # and would then silently discard the unterminated trailing block.
    if pending:
        raise ValueError(
            "{}: the last block is not terminated by an empty line".format(path))
    return blocks


def main(argv=None):
    """Split CoNLL-U files into train/dev portions and append them to files.

    Each input file is read into blocks, shuffled with a fresh
    ``random.Random(seed)`` (so every file is shuffled with the same seed,
    independently of the others), and cut at ``int(ratio * number_of_blocks)``.
    The first part goes to the train output and, when ``ratio < 1``, the rest
    goes to the dev output. Every per-file portion is preceded by a
    ``# newdoc id = <file basename>`` header line. Outputs are opened in append
    mode.

    Args:
        argv: command-line arguments without the program name; ``None`` makes
            argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("paths", nargs="+", type=str, help="Input CoNLL-U files")
    parser.add_argument("--ratio", default=0.95, type=float, help="Train/dev ratio")
    parser.add_argument("--train", default=None, type=str, help="Train file to append to")
    parser.add_argument("--dev", default=None, type=str, help="Dev file to append to")
    parser.add_argument("--seed", default=42, type=int, help="Random seed")
    args = parser.parse_args(argv)

    train, dev = [], []

    for path in args.paths:
        print("Loading {}".format(path), file=sys.stderr)

        header = "# newdoc id = {}\n".format(os.path.basename(path))
        blocks = _read_blocks(path)

        random.Random(args.seed).shuffle(blocks)
        cut = int(args.ratio * len(blocks))
        train.extend([header] + blocks[:cut])
        if args.ratio < 1:
            dev.extend([header] + blocks[cut:])

    for name, path, blocks in [("train", args.train, train), ("dev", args.dev, dev)]:
        if not blocks:
            continue
        if path is None:
            # --train/--dev are optional; without this guard open(None, ...)
            # raised TypeError only after all inputs had been processed.
            print("No --{} file given, skipping the {} data".format(name, name),
                  file=sys.stderr)
            continue

        with open(path, "a", encoding="utf-8") as output_file:
            output_file.write("".join(blocks))


if __name__ == "__main__":
    main()