sam2ai's picture
Synced repo using 'sync_with_huggingface' Github Action
6de3e11
raw
history blame
2.27 kB
import json
import mmap
import struct
from tqdm import tqdm
class DatasetWriter(object):
    """Append string records to a ``<prefix>.data`` / ``<prefix>.header`` file pair.

    Layout of ``<prefix>.data``, one entry per ``add_data`` call::

        [u32 key_len][key bytes (ASCII)][u32 data_len][data bytes (UTF-8)]

    The u32 length prefixes are 4-byte little-endian unsigned ints.

    Each line of ``<prefix>.header`` is ``key\\toffset\\tlength\\n``:
      - ``key`` is the 0-based record counter.
      - ``offset`` is the absolute position of the data payload in the
        ``.data`` file, i.e. just after the second length prefix.
      - ``length`` is the payload size in bytes.

    ``DatasetReader`` consumes the header file and slices payloads directly.

    Can be used as a context manager to guarantee both files are closed.
    """

    # Standard size ('<') guarantees exactly 4 bytes per prefix, which the
    # offset arithmetic below relies on. Native 'I' does not guarantee this
    # and would also make the byte order platform-dependent.
    _U32 = struct.Struct('<I')

    def __init__(self, prefix):
        """Create/truncate ``prefix + '.data'`` and ``prefix + '.header'``."""
        self.data_file = open(prefix + '.data', 'wb')
        try:
            self.header_file = open(prefix + '.header', 'wb')
        except BaseException:
            # Don't leak the already-opened data file if the header can't open.
            self.data_file.close()
            raise
        self.data_sum = 0  # number of records written; also the next key
        self.offset = 0    # current size of the .data file in bytes
        self.header = ''   # most recently written header line

    def add_data(self, data):
        """Append one record.

        Args:
            data: the record as a ``str``; it is stored UTF-8 encoded.
                ``DatasetReader`` later parses it with ``json.loads``.
        """
        key = str(self.data_sum)
        key_bytes = key.encode('ascii')
        payload = bytes(data, encoding="utf8")

        self.data_file.write(b''.join((
            self._U32.pack(len(key_bytes)),
            key_bytes,
            self._U32.pack(len(payload)),
            payload,
        )))

        # The payload starts after both length prefixes and the key.
        self.offset += 2 * self._U32.size + len(key_bytes)
        self.header = key + '\t' + str(self.offset) + '\t' + str(len(payload)) + '\n'
        self.header_file.write(self.header.encode('ascii'))
        self.offset += len(payload)
        self.data_sum += 1

    def close(self):
        """Flush and close both output files."""
        try:
            self.data_file.close()
        finally:
            self.header_file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False
class DatasetReader(object):
    """Memory-mapped random access to records written by ``DatasetWriter``.

    On construction:
      - the ``.header`` index is scanned;
      - every record is JSON-decoded;
      - only records whose ``"duration"`` lies in
        ``[min_duration, max_duration]`` are kept.

    Pass ``max_duration=-1`` to disable the upper bound. Durations are
    compared in whatever unit the records store; presumably seconds, but
    confirm against the data producer.

    Keys are the raw ``bytes`` key field from the header file. Call
    ``close()`` (or use as a context manager) to release the file and map.
    """

    def __init__(self, data_header_path, min_duration=0, max_duration=30):
        self.keys = []
        self.offset_dict = {}
        self.m = b''  # replaced by an mmap below unless the data file is empty
        self.fp = open(self._data_path(data_header_path), 'rb')
        try:
            self.m = self._map_file(self.fp)
            self._load_index(data_header_path, min_duration, max_duration)
        except BaseException:
            # Release the data file / mapping if indexing fails part-way.
            self.close()
            raise

    @staticmethod
    def _data_path(header_path):
        """Derive the .data path by replacing the LAST '.header' occurrence."""
        head, sep, tail = header_path.rpartition('.header')
        if not sep:
            # No '.header' in the path: same result as the original str.replace.
            return header_path
        return head + '.data' + tail

    @staticmethod
    def _map_file(fp):
        """Read-only mmap of ``fp``; ``b''`` for an empty file (mmap rejects size 0)."""
        size = fp.seek(0, 2)  # 2 == os.SEEK_END; returns the file size
        fp.seek(0)
        if size == 0:
            return b''
        return mmap.mmap(fp.fileno(), 0, access=mmap.ACCESS_READ)

    def _load_index(self, data_header_path, min_duration, max_duration):
        """Parse the header file and keep records within the duration bounds."""
        with open(data_header_path, 'rb') as header_file:
            # The progress label below means "reading data list".
            for line in tqdm(header_file, desc='读取数据列表'):
                line = line.rstrip(b'\r\n')
                if not line:
                    continue
                key, val_pos, val_len = line.split(b'\t')
                pos, length = int(val_pos), int(val_len)
                record = json.loads(str(self.m[pos:pos + length], encoding="utf-8"))
                duration = record["duration"]
                if duration < min_duration:
                    continue
                if max_duration != -1 and duration > max_duration:
                    continue
                self.keys.append(key)
                self.offset_dict[key] = (pos, length)

    def get_data(self, key):
        """Return the JSON-decoded record for ``key``, or ``None`` if absent.

        ``key`` is normally a ``bytes`` value from ``get_keys()``. A ``str``
        key is also accepted and UTF-8 encoded before lookup.
        """
        if isinstance(key, str):
            key = key.encode('utf-8')
        entry = self.offset_dict.get(key)
        if entry is None:
            return None
        pos, length = entry
        return json.loads(str(self.m[pos:pos + length], encoding="utf-8"))

    def get_keys(self):
        """Return the list of retained keys, in header-file order."""
        return self.keys

    def __len__(self):
        return len(self.keys)

    def close(self):
        """Close the memory map (if any) and the underlying data file."""
        if isinstance(self.m, mmap.mmap):
            self.m.close()
        self.fp.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False