"""Shuffle an MXNet indexed RecordIO dataset into a sibling ``shuffled_*`` directory.

A reader process streams the records of ``<input>/train.rec`` in random
order through a bounded queue; a writer process re-packs them into
``shuffled_<input>/train.rec`` with consecutive indices starting at 0.
"""
import argparse
import multiprocessing
import os
import time

import mxnet as mx
import numpy as np

# Maximum number of raw records buffered between the reader and the writer.
QUEUE_CAPACITY = 10240


def _shuffled_output_dir(input_path):
    """Return the sibling ``shuffled_<name>`` directory path for *input_path*.

    The path is normalised with ``os.path.abspath`` first, so any number of
    trailing slashes (``faces/``, ``faces//``) and relative forms such as
    ``.`` resolve to the real directory name instead of an empty basename.
    """
    normalized = os.path.abspath(input_path)
    parent = os.path.dirname(normalized)
    name = os.path.basename(normalized)
    return os.path.join(parent, f"shuffled_{name}")


def read_worker(args, q_in):
    """Read every record of the source dataset in random order into *q_in*.

    Record 0 is expected to be a metadata header (``flag > 0``) whose first
    label value is the exclusive upper bound of the record indices; records
    ``1 .. label[0] - 1`` are shuffled (no seed is set here) and queued as
    raw byte strings.

    A ``None`` sentinel is always queued last, even when reading fails, so
    the consumer never blocks forever waiting for more data.

    Args:
        args: Parsed command-line namespace; only ``args.input`` is read.
        q_in: Queue that receives the raw records followed by ``None``.

    Raises:
        ValueError: If record 0 is not a metadata header.
    """
    path_imgidx = os.path.join(args.input, "train.idx")
    path_imgrec = os.path.join(args.input, "train.rec")
    try:
        imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "r")
        try:
            header, _ = mx.recordio.unpack(imgrec.read_idx(0))
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if header.flag <= 0:
                raise ValueError(
                    f"{path_imgrec}: record 0 is not a metadata header "
                    f"(flag={header.flag})"
                )
            imgidx = np.arange(1, int(header.label[0]))
            np.random.shuffle(imgidx)
            for idx in imgidx:
                q_in.put(imgrec.read_idx(idx))
        finally:
            imgrec.close()
    finally:
        # Always signal end-of-stream; without it the writer would hang on
        # ``get()`` if this process failed part-way through.
        q_in.put(None)


def write_worker(args, q_out):
    """Consume records from *q_out* and write them to the shuffled dataset.

    The output goes to ``shuffled_<basename>`` next to ``args.input`` (the
    directory is created if needed). Records are re-indexed from 0 in arrival
    order. Progress is printed every 10000 records and the final count is
    printed once the ``None`` sentinel is received.

    Args:
        args: Parsed command-line namespace; only ``args.input`` is read.
        q_out: Queue yielding raw records, terminated by ``None``.
    """
    pre_time = time.time()

    output = _shuffled_output_dir(args.input)
    os.makedirs(output, exist_ok=True)

    path_imgidx = os.path.join(output, "train.idx")
    path_imgrec = os.path.join(output, "train.rec")
    save_record = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "w")
    count = 0
    try:
        # ``iter(callable, sentinel)`` stops as soon as ``get()`` returns None.
        for deq in iter(q_out.get, None):
            header, jpeg = mx.recordio.unpack(deq)
            # TODO it is currently not fully developed
            # Only a scalar label is kept: a non-float label is reduced to
            # its first element.
            if isinstance(header.label, float):
                label = header.label
            else:
                label = header.label[0]

            header = mx.recordio.IRHeader(
                flag=header.flag, label=label, id=header.id, id2=header.id2
            )
            save_record.write_idx(count, mx.recordio.pack(header, jpeg))
            count += 1
            if count % 10000 == 0:
                cur_time = time.time()
                print('save time:', cur_time - pre_time, ' count:', count)
                pre_time = cur_time
        print(count)
    finally:
        # Close the record files even if unpacking or writing fails.
        save_record.close()


def main(args):
    """Run the reader and writer processes and wait for them to finish.

    Args:
        args: Parsed command-line namespace with an ``input`` attribute.

    Raises:
        SystemExit: With a message (exit status 1) if either worker process
            ended with a non-zero exit code, so a truncated output is never
            reported as success.
    """
    queue = multiprocessing.Queue(QUEUE_CAPACITY)

    read_process = multiprocessing.Process(target=read_worker, args=(args, queue))
    read_process.daemon = True
    read_process.start()

    write_process = multiprocessing.Process(target=write_worker, args=(args, queue))
    write_process.start()
    write_process.join()

    if write_process.exitcode != 0:
        # The reader may be blocked on a full queue nobody drains any more.
        read_process.terminate()
    read_process.join()

    if write_process.exitcode != 0 or read_process.exitcode != 0:
        raise SystemExit(
            "shuffle failed: "
            f"reader exit code {read_process.exitcode}, "
            f"writer exit code {write_process.exitcode}"
        )


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('input', help='path to source rec.')
    main(parser.parse_args())