Spaces:
Configuration error
Configuration error
import pickle | |
from copy import deepcopy | |
import numpy as np | |
class IndexedDataset:
    """Random-access reader for a ``{path}.data`` / ``{path}.idx`` file pair.

    ``{path}.data`` holds pickled items stored back to back. ``{path}.idx``
    is a NumPy-saved dict whose ``'offsets'`` entry lists the byte offset at
    which each item starts, followed by one final end offset, so item ``i``
    occupies bytes ``offsets[i]`` to ``offsets[i + 1]``.

    SECURITY: both ``np.load(..., allow_pickle=True)`` and ``pickle.loads``
    can execute arbitrary code while decoding. Only open dataset files that
    come from a trusted source.
    """

    def __init__(self, path, num_cache=1):
        """Open the dataset stored at ``path``.

        Args:
            path: Common prefix of the ``.idx`` and ``.data`` files.
            num_cache: Maximum number of recently read items to keep in
                memory. ``0`` (or less) disables caching.
        """
        super().__init__()
        self.path = path
        # Assigned before any I/O so that __del__ stays safe if loading the
        # index below raises.
        self.data_file = None
        self.data_offsets = np.load(f"{path}.idx", allow_pickle=True).item()['offsets']
        self.data_file = open(f"{path}.data", 'rb', buffering=-1)
        # Most recently read entries first, as (index, item) tuples.
        self.cache = []
        self.num_cache = num_cache

    def check_index(self, i):
        """Raise ``IndexError`` unless ``0 <= i < len(self)``.

        Negative indices are rejected rather than wrapped around.
        """
        if i < 0 or i >= len(self.data_offsets) - 1:
            raise IndexError('index out of range')

    def __del__(self):
        # getattr guards against instances whose __init__ never ran.
        data_file = getattr(self, 'data_file', None)
        if data_file:
            data_file.close()

    def __getitem__(self, i):
        """Return item ``i``, reading it from disk unless it is cached.

        Raises:
            IndexError: If ``i`` is outside ``[0, len(self))``.
        """
        self.check_index(i)
        if self.num_cache > 0:
            for cached_index, cached_item in self.cache:
                if cached_index == i:
                    # Hand out a copy so that a caller mutating the result
                    # cannot corrupt the cached entry for later reads.
                    return deepcopy(cached_item)
        start, end = self.data_offsets[i], self.data_offsets[i + 1]
        self.data_file.seek(start)
        item = pickle.loads(self.data_file.read(end - start))
        if self.num_cache > 0:
            # Keep at most num_cache entries, newest first. Slicing with
            # [:-1] here would pin the cache at a single entry because it
            # starts out empty.
            self.cache = [(i, deepcopy(item))] + self.cache[:self.num_cache - 1]
        return item

    def __len__(self):
        """Return the number of items (one less than the number of offsets)."""
        return len(self.data_offsets) - 1
class IndexedDatasetBuilder:
    """Sequential writer for a ``{path}.data`` / ``{path}.idx`` file pair.

    Items are pickled and appended to ``{path}.data``. ``finalize`` then
    writes ``{path}.idx``, a NumPy-saved dict whose ``'offsets'`` list holds
    the start offset of every item plus one final end offset.
    """

    def __init__(self, path):
        """Create (or truncate) ``{path}.data`` and start an empty index.

        Args:
            path: Common prefix of the ``.data`` and ``.idx`` files.
        """
        self.path = path
        self.out_file = open(f"{path}.data", 'wb')
        # byte_offsets[k] is where item k starts; the last entry is the
        # current end of the data file.
        self.byte_offsets = [0]

    def add_item(self, item):
        """Pickle ``item``, append it to the data file and record its end offset."""
        payload = pickle.dumps(item)
        num_written = self.out_file.write(payload)
        self.byte_offsets.append(self.byte_offsets[-1] + num_written)

    def finalize(self):
        """Close the data file and write the offsets index to ``{path}.idx``."""
        self.out_file.close()
        # An open file object is passed on purpose: given a bare path,
        # np.save would append ".npy" to the name. The with-block ensures
        # the index is flushed and closed deterministically.
        with open(f"{self.path}.idx", 'wb') as idx_file:
            np.save(idx_file, {'offsets': self.byte_offsets})
if __name__ == "__main__":
    # Round-trip smoke test: build a small dataset under /tmp, then read
    # random items back and compare them with the in-memory originals.
    import random

    from tqdm import tqdm

    ds_path = '/tmp/indexed_ds_example'
    size = 100

    items = []
    for _ in range(size):
        # "a" is drawn before "b" for every item.
        first = np.random.normal(size=[10000, 10])
        second = np.random.normal(size=[10000, 10])
        items.append({"a": first, "b": second})

    builder = IndexedDatasetBuilder(ds_path)
    for position in tqdm(range(size)):
        builder.add_item(items[position])
    builder.finalize()

    ds = IndexedDataset(ds_path)
    for _ in tqdm(range(10000)):
        idx = random.randint(0, size - 1)
        loaded, expected = ds[idx]['a'], items[idx]['a']
        assert (loaded == expected).all()