#!/usr/bin/env python3
"""Filter and reshuffle webdataset shards.

Reads every shard matched by a brace-expansion pattern, optionally drops
samples whose ``__key__`` does not contain one of the allowed categories,
and writes the surviving samples to ``--n-write-streams`` output shard
streams, picking a stream uniformly at random for each sample.
"""

import contextlib
import math
import os
import random

import braceexpand
import webdataset as wds

# Default category list, resolved relative to this script's directory.
DEFAULT_CATS_FILE = os.path.join(
    os.path.dirname(__file__), '..', 'configs', 'places2-categories_157.txt'
)


def is_good_key(key, cats):
    """Return True if any category in *cats* occurs as a substring of *key*.

    The test is ``c in key`` (substring containment), not exact equality.
    """
    return any(c in key for c in cats)


def _load_categories(path):
    """Return the first space-separated token of every non-empty line of *path*, as a set."""
    with open(path, 'r') as f:
        return {line.strip().split(' ')[0] for line in f if line.strip()}


def main(args):
    """Copy (optionally filtered) samples from the input shards to the output shards.

    Input shards are split into up to ``n_read_streams`` chunks, each read by
    its own shuffled iterator; on every step a live reader is picked at
    random, and each accepted sample goes to a randomly chosen output writer.

    Args:
        args: parsed CLI namespace with ``categories`` (path to a category
            file, or ``'nofilter'`` to keep every sample), ``infile``
            (brace-expansion pattern of input shards), ``outpattern``
            (``str.format`` template receiving the output stream index),
            ``n_read_streams``, ``n_write_streams``, ``shuffle_buffer`` and
            ``print_freq``.

    Raises:
        ValueError: if ``n_read_streams`` or ``n_write_streams`` is not positive.
    """
    for name in ('n_read_streams', 'n_write_streams'):
        if getattr(args, name) <= 0:
            raise ValueError(f'{name} must be positive, got {getattr(args, name)}')

    if args.categories == 'nofilter':
        good_categories = None
    else:
        good_categories = _load_categories(args.categories)

    all_input_files = list(braceexpand.braceexpand(args.infile))
    # Split the inputs into at most n_read_streams contiguous chunks, each read
    # by its own shuffled iterator. max(1, ...) keeps the range() step non-zero
    # if the expansion yields no files.
    chunk_size = max(1, math.ceil(len(all_input_files) / args.n_read_streams))
    input_iterators = [
        iter(wds.Dataset(all_input_files[start:start + chunk_size]).shuffle(args.shuffle_buffer))
        for start in range(0, len(all_input_files), chunk_size)
    ]

    with contextlib.ExitStack() as stack:
        output_datasets = []
        for i in range(args.n_write_streams):
            writer = wds.ShardWriter(args.outpattern.format(i))
            # Close every writer on exit (normal or exceptional) so the last,
            # partially filled shard of each output stream is finalized.
            stack.callback(writer.close)
            output_datasets.append(writer)

        good_readers = list(range(len(input_iterators)))
        step_i = 0
        good_samples = 0
        bad_samples = 0
        while good_readers:
            if step_i % args.print_freq == 0:
                print(f'Iterations done {step_i}; readers alive {good_readers}; good samples {good_samples}; bad samples {bad_samples}')
            step_i += 1

            # Interleave the readers by picking a live one uniformly at random.
            ri = random.choice(good_readers)
            try:
                sample = next(input_iterators[ri])
            except StopIteration:
                # This reader's chunk is exhausted; stop choosing it.
                good_readers.remove(ri)
                continue

            if good_categories is not None and not is_good_key(sample['__key__'], good_categories):
                bad_samples += 1
                continue

            wi = random.randint(0, args.n_write_streams - 1)
            output_datasets[wi].write(sample)
            good_samples += 1

    # Final report: the periodic one above can be up to print_freq steps stale.
    print(f'Iterations done {step_i}; readers alive {good_readers}; good samples {good_samples}; bad samples {bad_samples}')


if __name__ == '__main__':
    import argparse

    aparser = argparse.ArgumentParser(description='Filter and reshuffle webdataset shards.')
    aparser.add_argument('--categories', type=str, default=DEFAULT_CATS_FILE,
                         help="category list file (first token per line), or 'nofilter' to keep all samples")
    aparser.add_argument('--shuffle-buffer', type=int, default=10000,
                         help='shuffle buffer size of each input reader')
    aparser.add_argument('--n-read-streams', type=int, default=10,
                         help='number of parallel input readers to interleave')
    aparser.add_argument('--n-write-streams', type=int, default=10,
                         help='number of output shard streams')
    aparser.add_argument('--print-freq', type=int, default=1000,
                         help='print progress every N iterations')
    aparser.add_argument('infile', type=str,
                         help='brace-expansion pattern of input shards')
    aparser.add_argument('outpattern', type=str,
                         help='output pattern; str.format() receives the output stream index')
    main(aparser.parse_args())