import json
import mmap
import struct

from tqdm import tqdm


class DatasetWriter(object):
    def __init__(self, prefix):
        # Open the binary data file and its header (index) file for writing
        self.data_file = open(prefix + '.data', 'wb')
        self.header_file = open(prefix + '.header', 'wb')
        self.data_sum = 0
        self.offset = 0
        self.header = ''

    def add_data(self, data):
        key = str(self.data_sum)
        data = bytes(data, encoding="utf8")
        # Write one record: key length, key, value length, value
        self.data_file.write(struct.pack('I', len(key)))
        self.data_file.write(key.encode('ascii'))
        self.data_file.write(struct.pack('I', len(data)))
        self.data_file.write(data)
        # Skip past the record framing so the header stores the value's own offset and length
        self.offset += 4 + len(key) + 4
        self.header = key + '\t' + str(self.offset) + '\t' + str(len(data)) + '\n'
        self.header_file.write(self.header.encode('ascii'))
        self.offset += len(data)
        self.data_sum += 1

    def close(self):
        self.data_file.close()
        self.header_file.close()


class DatasetReader(object):
    def __init__(self, data_header_path, min_duration=0, max_duration=30):
        self.keys = []
        self.offset_dict = {}
        self.fp = open(data_header_path.replace('.header', '.data'), 'rb')
        self.m = mmap.mmap(self.fp.fileno(), 0, access=mmap.ACCESS_READ)
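        # Read the header with a context manager so its file handle is closed afterwards
        with open(data_header_path, 'rb') as header_file:
            for line in tqdm(header_file, desc='Reading data list'):
                key, val_pos, val_len = line.split(b'\t')
                val_pos, val_len = int(val_pos), int(val_len)
                data = self.m[val_pos:val_pos + val_len]
                data = str(data, encoding="utf-8")
                data = json.loads(data)
                # Skip samples whose duration is out of range; max_duration == -1 disables the upper bound
                if data["duration"] < min_duration:
                    continue
                if max_duration != -1 and data["duration"] > max_duration:
                    continue
                # Keys are kept as bytes, exactly as read from the header file
                self.keys.append(key)
                self.offset_dict[key] = (val_pos, val_len)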

    # Look up one record by key (bytes) and return the decoded JSON, or None if the key is unknown
    def get_data(self, key):
        p = self.offset_dict.get(key, None)
        if p is None:
            return None
        val_pos, val_len = p
        data = self.m[val_pos:val_pos + val_len]
        data = str(data, encoding="utf-8")
        return json.loads(data)

    # Return the keys of all records that passed the duration filter
    def get_keys(self):
        return self.keys

    def __len__(self):
        return len(self.keys)
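

# A minimal sketch, not part of the original module: the hypothetical helper
# `iter_data_records` walks a .data file sequentially using the record framing
# that DatasetWriter.add_data produces (unsigned-int key length, key,
# unsigned-int value length, value). It could be used, for example, to rebuild
# a lost .header file. It assumes the file was written on a platform with the
# same native byte order and integer size, since the writer uses struct format 'I'.
def iter_data_records(data_path):
    """Yield (key, val_pos, val_len) for each record found in a .data file."""
    int_size = struct.calcsize('I')
    with open(data_path, 'rb') as f:
        while True:
            raw = f.read(int_size)
            if len(raw) < int_size:
                # Clean end of file (or a truncated trailing length field)
                break
            key_len, = struct.unpack('I', raw)
            key = f.read(key_len)
            # Raises struct.error if the file is truncated inside a record
            val_len, = struct.unpack('I', f.read(int_size))
            # The value starts right after its length field, which matches
            # the offset DatasetWriter stores in the header
            val_pos = f.tell()
            # Skip over the value without reading it
            f.seek(val_len, 1)
            yield key, val_pos, val_len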