import pickle
from copy import deepcopy

import numpy as np


class IndexedDataset:
    """Random-access dataset backed by a pickled .data file and a NumPy offset index."""

    def __init__(self, path, num_cache=1):
        super().__init__()
        self.path = path
        self.data_file = None
        # Byte offsets into the data file; offsets i and i+1 bracket item i.
        self.data_offsets = np.load(f"{path}.idx", allow_pickle=True).item()['offsets']
        self.data_file = open(f"{path}.data", 'rb', buffering=-1)
        self.cache = []
        self.num_cache = num_cache

    def check_index(self, i):
        if i < 0 or i >= len(self.data_offsets) - 1:
            raise IndexError('index out of range')

    def __del__(self):
        if self.data_file:
            self.data_file.close()

    def __getitem__(self, i):
        self.check_index(i)
        if self.num_cache > 0:
            for c in self.cache:
                if c[0] == i:
                    return c[1]
        self.data_file.seek(self.data_offsets[i])
        b = self.data_file.read(self.data_offsets[i + 1] - self.data_offsets[i])
        item = pickle.loads(b)
        if self.num_cache > 0:
            # Insert at the front and keep at most num_cache entries.
            # (The previous slice, self.cache[:-1], capped the cache at a single
            # entry regardless of num_cache.)
            self.cache = [(i, deepcopy(item))] + self.cache[:self.num_cache - 1]
        return item

    def __len__(self):
        return len(self.data_offsets) - 1


class IndexedDatasetBuilder:
    """Appends pickled items to a .data file and records their byte offsets."""

    def __init__(self, path):
        self.path = path
        self.out_file = open(f"{path}.data", 'wb')
        self.byte_offsets = [0]

    def add_item(self, item):
        s = pickle.dumps(item)
        n_bytes = self.out_file.write(s)  # renamed to avoid shadowing the built-in `bytes`
        self.byte_offsets.append(self.byte_offsets[-1] + n_bytes)

    def finalize(self):
        self.out_file.close()
        # Pass an open file object so np.save keeps the ".idx" name instead of
        # appending a ".npy" suffix; the with-block ensures the file is closed.
        with open(f"{self.path}.idx", 'wb') as f:
            np.save(f, {'offsets': self.byte_offsets})


if __name__ == "__main__":
    import random

    from tqdm import tqdm

    ds_path = '/tmp/indexed_ds_example'
    size = 100
    items = [{"a": np.random.normal(size=[10000, 10]),
              "b": np.random.normal(size=[10000, 10])}
             for _ in range(size)]
    builder = IndexedDatasetBuilder(ds_path)
    for i in tqdm(range(size)):
        builder.add_item(items[i])
    builder.finalize()
    ds = IndexedDataset(ds_path)
    for i in tqdm(range(10000)):
        idx = random.randint(0, size - 1)
        assert (ds[idx]['a'] == items[idx]['a']).all()
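
# A minimal sketch of consuming this dataset from a PyTorch DataLoader, which
# accepts any map-style object exposing __getitem__ and __len__. This assumes
# torch is installed; it is not required by the module above. Because the
# dataset holds a single open file handle, keep num_workers=0 (or reopen the
# file per worker) so forked workers do not share one seek position.
#
#     from torch.utils.data import DataLoader
#
#     loader = DataLoader(IndexedDataset(ds_path), batch_size=4, shuffle=True,
#                         num_workers=0, collate_fn=lambda batch: batch)
#     for batch in loader:
#         ...  # each batch is a list of dicts holding the "a" and "b" arrays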