File size: 2,428 Bytes
03da825
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import argparse
import multiprocessing
import os
import time

import mxnet as mx
import numpy as np


def read_worker(args, q_in):
    """Stream every image record of ``<args.input>/train.rec`` into ``q_in`` in random order.

    Record 0 of the source file must be a header record (``flag > 0``) whose
    first label value is the end of the image index range; the image records
    at indices ``1 .. int(label[0]) - 1`` are shuffled and enqueued as raw
    packed bytes.

    A ``None`` sentinel is always put on ``q_in`` when this worker stops --
    also when it fails -- so the consumer is never left blocking on records
    that will not arrive.

    Args:
        args: parsed CLI namespace; ``args.input`` is the directory holding
            ``train.idx`` and ``train.rec``.
        q_in: queue that receives the packed record bytes, then ``None``.

    Raises:
        ValueError: if record 0 is not a header record.
    """
    imgrec = None
    try:
        path_imgidx = os.path.join(args.input, "train.idx")
        path_imgrec = os.path.join(args.input, "train.rec")
        imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "r")

        header, _ = mx.recordio.unpack(imgrec.read_idx(0))
        # Explicit check rather than ``assert``: asserts vanish under ``python -O``.
        if not header.flag > 0:
            raise ValueError(
                f"record 0 of {path_imgrec} is not a header record (flag={header.flag})"
            )

        imgidx = np.arange(1, int(header.label[0]))
        np.random.shuffle(imgidx)

        for idx in imgidx:
            q_in.put(imgrec.read_idx(idx))
    finally:
        # Sentinel first (same order as before), then release the file handles.
        q_in.put(None)
        if imgrec is not None:
            imgrec.close()


def write_worker(args, q_out):
    """Write records received from ``q_out`` into a new indexed record file.

    The output directory is ``shuffled_<name>`` created next to the source
    directory (e.g. ``/data/faces`` -> ``/data/shuffled_faces``) and receives
    ``train.idx`` / ``train.rec``. Records are re-indexed consecutively from 0
    in arrival order; no header record is written at index 0. Each label is
    reduced to a scalar. Reading stops at the first ``None`` sentinel.

    Args:
        args: parsed CLI namespace; ``args.input`` is the source directory.
        q_out: queue yielding packed record bytes, terminated by ``None``.
    """
    pre_time = time.time()

    # normpath collapses any number of trailing separators (and turns an empty
    # string into '.'), so basename() yields the source directory's own name.
    # A local is used instead of reassigning ``args.input``.
    src_dir = os.path.normpath(args.input)
    dirname = os.path.dirname(src_dir)
    basename = os.path.basename(src_dir)
    output = os.path.join(dirname, f"shuffled_{basename}")
    os.makedirs(output, exist_ok=True)

    path_imgidx = os.path.join(output, "train.idx")
    path_imgrec = os.path.join(output, "train.rec")
    save_record = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "w")
    count = 0
    try:
        # Two-argument iter() keeps calling q_out.get() until it returns None.
        for item in iter(q_out.get, None):
            header, jpeg = mx.recordio.unpack(item)
            # TODO: multi-label records are not fully supported; a scalar
            # (float) label is kept as is, otherwise only the first value is kept.
            if isinstance(header.label, float):
                label = header.label
            else:
                label = header.label[0]

            header = mx.recordio.IRHeader(flag=header.flag, label=label, id=header.id, id2=header.id2)
            save_record.write_idx(count, mx.recordio.pack(header, jpeg))
            count += 1
            if count % 10000 == 0:
                cur_time = time.time()
                print('save time:', cur_time - pre_time, ' count:', count)
                pre_time = cur_time
        print(count)
    finally:
        # Close even on failure so the .idx/.rec handles are released.
        save_record.close()


def main(args):
    """Shuffle ``<args.input>/train.rec`` into a sibling ``shuffled_*`` directory.

    A reader process pushes shuffled records through a bounded queue to a
    writer process, which writes the new record file.

    Args:
        args: parsed CLI namespace with an ``input`` attribute.

    Raises:
        RuntimeError: if either worker process exits with a non-zero code.
    """
    # Bounded so the reader cannot buffer unboundedly far ahead of the writer.
    queue = multiprocessing.Queue(10240)
    read_process = multiprocessing.Process(target=read_worker, args=(args, queue))
    # Daemonic: terminated automatically when the main process exits, e.g. if
    # the writer failed while the reader is blocked on a full queue.
    read_process.daemon = True
    read_process.start()
    write_process = multiprocessing.Process(target=write_worker, args=(args, queue))
    write_process.start()
    write_process.join()
    if write_process.exitcode != 0:
        raise RuntimeError(f"write_worker exited with code {write_process.exitcode}")
    # The writer only returns after receiving the reader's None sentinel, so
    # the reader is already past its last put() and this join is short.
    read_process.join()
    if read_process.exitcode != 0:
        raise RuntimeError(f"read_worker exited with code {read_process.exitcode}")


if __name__ == '__main__':
    # Command-line entry point: the positional argument is a directory, since
    # the workers join it with "train.idx" / "train.rec".
    parser = argparse.ArgumentParser(
        description='Shuffle the records of <input>/train.rec and write them to '
                    'a shuffled_<name> directory next to <input>.')
    parser.add_argument('input', help='directory containing train.rec and train.idx')
    main(parser.parse_args())