#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Split a large file into a train and valid set while respecting document
boundaries. Documents should be separated by a single empty line.

With ``--lines`` every non-blank line is treated as its own document.
"""

import argparse
import random
import sys


def _report_progress(line_no):
    """Print a progress marker to stderr for the given 0-based line number."""
    if line_no % 1000000 == 0:
        print(line_no, file=sys.stderr, end="", flush=True)
    elif line_no % 100000 == 0:
        print(".", file=sys.stderr, end="", flush=True)


def _iter_docs(lines, split_lines):
    """Yield documents (lists of newline-terminated lines) from *lines*.

    A line that is empty after ``str.strip()`` ends the current document.
    When *split_lines* is true, each non-blank line is yielded as a
    one-line document instead.
    """
    doc = []
    for line_no, line in enumerate(lines):
        if line.strip():
            # Only the final line of a file can lack its newline; add it so
            # the line cannot fuse with whatever is written after it.
            if not line.endswith("\n"):
                line += "\n"
            doc.append(line)
            if split_lines:
                yield doc
                doc = []
        elif doc:
            # Blank line closes the current document. Repeated or leading
            # blank lines are ignored rather than producing empty documents
            # that would occupy sample slots.
            yield doc
            doc = []
        _report_progress(line_no)
    if doc:
        yield doc


def _reservoir_split(docs, k):
    """Split *docs* into ``(sample, remainder)`` using reservoir sampling.

    The first *k* documents fill the sample. Each later document, at 0-based
    position ``seen``, draws ``random.randrange(seen + 1)``; a draw below *k*
    evicts that sample slot into the remainder, otherwise the new document
    goes to the remainder directly.
    """
    sample = []
    remainder = []
    for seen, doc in enumerate(docs):
        if len(sample) < k:
            sample.append(doc)
            continue
        slot = random.randrange(seen + 1)
        if slot < k:
            remainder.append(sample[slot])
            sample[slot] = doc
        else:
            remainder.append(doc)
    return sample, remainder


def _write_docs(path, docs, split_lines):
    """Write *docs* to *path*, blank-line separated unless *split_lines*."""
    with open(path, "w", encoding="utf-8") as out:
        for index, doc in enumerate(docs):
            if index > 0 and not split_lines:
                out.write("\n")
            out.writelines(doc)


def main(argv=None):
    """Run the splitter.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). Defaults to ``sys.argv[1:]`` when ``None``.

    Exits with a non-zero status when ``-k`` is missing or negative, or when
    the input holds fewer than ``k`` documents (lines with ``--lines``); in
    the latter case no output file is written.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input")
    parser.add_argument("sample_output", help="train output file")
    parser.add_argument("remainder_output", help="valid output file")
    parser.add_argument(
        "-k",
        type=int,
        required=True,
        help="number of documents (or lines) to put in sample_output",
    )
    parser.add_argument(
        "--lines", action="store_true", help="split lines instead of docs"
    )
    args = parser.parse_args(argv)

    if args.k < 0:
        parser.error("-k must be a non-negative integer")

    with open(args.input, "r", encoding="utf-8") as h:
        sample, remainder = _reservoir_split(_iter_docs(h, args.lines), args.k)
    # Terminate the progress line printed while reading.
    print(file=sys.stderr, flush=True)

    if len(sample) != args.k:
        unit = "lines" if args.lines else "documents"
        sys.exit(
            f"error: requested -k {args.k} but input contains only "
            f"{len(sample)} {unit}"
        )

    _write_docs(args.sample_output, sample, args.lines)
    _write_docs(args.remainder_output, remainder, args.lines)


if __name__ == "__main__":
    main()