lama / bin /filter_sharded_dataset.py
AK391
files
d380b77
#!/usr/bin/env python3
import math
import os
import random
import braceexpand
import webdataset as wds
# Default category whitelist: configs/places2-categories_157.txt, located one
# directory above the directory containing this script.
DEFAULT_CATS_FILE = os.path.join(
    os.path.dirname(__file__),
    '..',
    'configs',
    'places2-categories_157.txt',
)
def is_good_key(key, cats):
    """Tell whether *key* mentions at least one of the given categories.

    Matching is plain substring containment: the result is True as soon as some
    element of *cats* occurs anywhere inside *key*. An empty *cats* never
    matches.
    """
    for category in cats:
        if category in key:
            return True
    return False
def _load_good_categories(path):
    """Read the category whitelist used by ``main``.

    Returns ``None`` when *path* is the literal string ``'nofilter'``, meaning
    "do not filter". Otherwise returns the set formed by the first
    whitespace-separated token of every non-blank line in the file at *path*.
    """
    if path == 'nofilter':
        return None
    with open(path, 'r', encoding='utf-8') as f:
        return {line.split()[0] for line in f if line.strip()}


def main(args):
    """Re-shard a webdataset, optionally keeping only whitelisted categories.

    The input shards (``args.infile``, brace-expanded) are split into at most
    ``args.n_read_streams`` contiguous groups. Each group is read through its own
    ``wds.Dataset(...).shuffle(args.shuffle_buffer)`` iterator.

    Each step takes one sample from a randomly chosen live reader. If the
    sample's ``__key__`` passes ``is_good_key`` (or filtering is disabled), it is
    written to a randomly chosen one of ``args.n_write_streams`` ``wds.ShardWriter``
    outputs. The writer patterns are ``args.outpattern.format(index)``.

    The loop ends when every reader is exhausted. All writers are then closed,
    even if an error occurred.

    Args:
        args: parsed command-line namespace providing ``categories``, ``infile``,
            ``outpattern``, ``n_read_streams``, ``n_write_streams``,
            ``shuffle_buffer`` and ``print_freq``.
    """
    good_categories = _load_good_categories(args.categories)

    all_input_files = list(braceexpand.braceexpand(args.infile))
    # Rounding up gives at most n_read_streams groups without dropping any file.
    # max(1, ...) keeps range()'s step non-zero.
    chunk_size = max(1, int(math.ceil(len(all_input_files) / args.n_read_streams)))
    input_iterators = [
        iter(wds.Dataset(all_input_files[start:start + chunk_size]).shuffle(args.shuffle_buffer))
        for start in range(0, len(all_input_files), chunk_size)
    ]

    output_datasets = [wds.ShardWriter(args.outpattern.format(i)) for i in range(args.n_write_streams)]
    try:
        good_readers = list(range(len(input_iterators)))  # indices of readers not yet exhausted
        step_i = 0
        good_samples = 0
        bad_samples = 0
        while good_readers:
            # print_freq <= 0 disables periodic progress output instead of
            # failing with a modulo by zero.
            if args.print_freq > 0 and step_i % args.print_freq == 0:
                print(f'Iterations done {step_i}; readers alive {good_readers}; good samples {good_samples}; bad samples {bad_samples}')
            step_i += 1

            # Interleave the readers by drawing from a random live one each step.
            ri = random.choice(good_readers)
            try:
                sample = next(input_iterators[ri])
            except StopIteration:
                good_readers.remove(ri)
                continue

            if good_categories is not None and not is_good_key(sample['__key__'], good_categories):
                bad_samples += 1
                continue

            wi = random.randint(0, args.n_write_streams - 1)
            output_datasets[wi].write(sample)
            good_samples += 1

        print(f'Finished after {step_i} iterations; good samples {good_samples}; bad samples {bad_samples}')
    finally:
        # Each ShardWriter keeps its current shard open. Without close() the
        # last shard of every writer may be left incomplete/unflushed.
        for writer in output_datasets:
            writer.close()
if __name__ == '__main__':
    import argparse

    def positive_int(value):
        """argparse ``type=``: parse *value* as an int and require it to be >= 1.

        This rejects values that would otherwise crash ``main`` later: a zero
        read-stream count divides by zero, a zero write-stream count breaks
        ``random.randint``, and a zero print frequency causes a modulo by zero.
        """
        ivalue = int(value)
        if ivalue < 1:
            raise argparse.ArgumentTypeError(f'expected a positive integer, got {value!r}')
        return ivalue

    aparser = argparse.ArgumentParser(
        description='Re-shard a webdataset, optionally keeping only samples of selected categories.')
    aparser.add_argument('--categories', type=str, default=DEFAULT_CATS_FILE,
                         help="Category list file. The first token of each non-blank line is a category, "
                              "and a sample is kept only if its __key__ contains one of them. "
                              "Pass 'nofilter' to keep every sample.")
    aparser.add_argument('--shuffle-buffer', type=int, default=10000,
                         help='Shuffle buffer size used by each read stream.')
    aparser.add_argument('--n-read-streams', type=positive_int, default=10,
                         help='Maximum number of input readers. Input shards are split into this many contiguous groups.')
    aparser.add_argument('--n-write-streams', type=positive_int, default=10,
                         help='Number of output shard writers. Each kept sample goes to a randomly chosen one.')
    aparser.add_argument('--print-freq', type=positive_int, default=1000,
                         help='Print progress every this many iterations.')
    aparser.add_argument('infile', type=str,
                         help="Input shard path(s). Brace patterns such as 'data-{000000..000099}.tar' are expanded.")
    aparser.add_argument('outpattern', type=str,
                         help="Output pattern. '{}' is replaced by the write-stream index and the result is passed "
                              "to webdataset.ShardWriter (e.g. 'out-{}-%%06d.tar').")
    main(aparser.parse_args())