| import glob |
| import io |
| import os |
| import re |
| import zipfile |
| from abc import ABC, abstractmethod |
| from contextlib import contextmanager |
| from dataclasses import dataclass |
| from typing import Dict, Iterator, List, Optional, Sequence, Tuple |
|
|
| import numpy as np |
|
|
|
|
@dataclass
class NumpyArrayInfo:
    """
    Information about an array in an npz file.

    Parsed from the .npy headers stored inside the zip archive without
    reading any array data, so it is cheap even for very large files.
    """

    # Key of the array inside the npz file (zip member name minus ".npy").
    name: str
    # Element dtype as parsed from the .npy header.
    dtype: np.dtype
    # Full array shape; axis 0 is treated as the batch dimension by callers.
    shape: Tuple[int, ...]

    @classmethod
    def infos_from_first_file(cls, glob_path: str) -> Dict[str, "NumpyArrayInfo"]:
        """
        Like infos_from_file(), applied to the first (sorted) file matched by
        the glob pattern. A trailing "[:N]" truncation suffix on glob_path is
        accepted and ignored here.
        """
        paths, _ = _npz_paths_and_length(glob_path)
        return cls.infos_from_file(paths[0])

    @classmethod
    def infos_from_file(cls, npz_path: str) -> Dict[str, "NumpyArrayInfo"]:
        """
        Extract the info of every array in an npz file.

        :param npz_path: path to an existing .npz file.
        :return: mapping from array name to its NumpyArrayInfo.
        :raises FileNotFoundError: if npz_path does not exist.
        :raises ValueError: if an array uses an unsupported .npy format version.
        """
        if not os.path.exists(npz_path):
            raise FileNotFoundError(f"batch of samples was not found: {npz_path}")
        results = {}
        with open(npz_path, "rb") as f:
            with zipfile.ZipFile(f, "r") as zip_f:
                for name in zip_f.namelist():
                    # Non-array members of the archive are skipped.
                    if not name.endswith(".npy"):
                        continue
                    key_name = name[: -len(".npy")]
                    with zip_f.open(name, "r") as arr_f:
                        # Read only the .npy magic + header, never the data.
                        version = np.lib.format.read_magic(arr_f)
                        if version == (1, 0):
                            header = np.lib.format.read_array_header_1_0(arr_f)
                        elif version == (2, 0):
                            header = np.lib.format.read_array_header_2_0(arr_f)
                        else:
                            raise ValueError(f"unknown numpy array version: {version}")
                        shape, _, dtype = header
                        results[key_name] = cls(name=key_name, dtype=dtype, shape=shape)
        return results

    @property
    def elem_shape(self) -> Tuple[int, ...]:
        # Shape of a single element, i.e. the full shape minus the batch axis.
        return self.shape[1:]

    def validate(self):
        """
        Sanity-check shape/dtype against what this codebase expects for the
        given array name; raises ValueError on mismatch.
        """
        if self.name in {"R", "G", "B"}:
            # Per-channel arrays are expected to be flat batches: [N x D].
            if len(self.shape) != 2:
                raise ValueError(
                    f"expecting exactly 2-D shape for '{self.name}' but got: {self.shape}"
                )
        elif self.name == "arr_0":
            if len(self.shape) < 2:
                raise ValueError(f"expecting at least 2-D shape but got: {self.shape}")
            elif len(self.shape) == 3:
                # 3-D is treated as an audio batch — presumably [N x T x C];
                # TODO(review): confirm against callers.
                if not np.issubdtype(self.dtype, np.floating):
                    raise ValueError(
                        f"invalid dtype for audio batch: {self.dtype} (expected float)"
                    )
            elif self.dtype != np.uint8:
                raise ValueError(f"invalid dtype for image batch: {self.dtype} (expected uint8)")
|
|
|
|
class NpzStreamer:
    """
    Streams fixed-size batches of arrays out of one or more npz files matched
    by a glob pattern (optionally truncated via a trailing "[:N]" suffix).
    """

    def __init__(self, glob_path: str):
        # trunc_length is None for "no limit", otherwise the maximum number
        # of rows to yield across all matched files combined.
        self.paths, self.trunc_length = _npz_paths_and_length(glob_path)
        # Metadata is taken from the first file only; the remaining files are
        # assumed to share the same keys/dtypes — TODO confirm upstream.
        self.infos = NumpyArrayInfo.infos_from_file(self.paths[0])

    def keys(self) -> List[str]:
        """Names of the arrays available in the npz files."""
        return list(self.infos.keys())

    def stream(self, batch_size: int, keys: Sequence[str]) -> Iterator[Dict[str, np.ndarray]]:
        """
        Yield dicts mapping each requested key to a [batch_size x ...] array.

        Batches are re-packed across file boundaries, so every yielded batch
        has exactly batch_size rows except possibly the final one (which may
        be smaller and is still yielded).
        """
        cur_batch = None  # partial batch carried across reads/files
        num_remaining = self.trunc_length
        for path in self.paths:
            if num_remaining is not None and num_remaining <= 0:
                break
            with open_npz_arrays(path, keys) as readers:
                combined_reader = CombinedReader(keys, readers)
                while num_remaining is None or num_remaining > 0:
                    # Request just enough rows to top up the current batch,
                    # without overshooting the global truncation limit.
                    read_bs = batch_size
                    if cur_batch is not None:
                        read_bs -= _dict_batch_size(cur_batch)
                    if num_remaining is not None:
                        read_bs = min(read_bs, num_remaining)

                    batch = combined_reader.read_batch(read_bs)
                    if batch is None:
                        # Current file is exhausted; move on to the next path.
                        break
                    if num_remaining is not None:
                        num_remaining -= _dict_batch_size(batch)
                    if cur_batch is None:
                        cur_batch = batch
                    else:
                        # Append newly read rows to the carried partial batch.
                        cur_batch = {
                            k: np.concatenate([cur_batch[k], v], axis=0)
                            for k, v in batch.items()
                        }
                    if _dict_batch_size(cur_batch) == batch_size:
                        yield cur_batch
                        cur_batch = None
        if cur_batch is not None:
            # Flush the final (possibly smaller) batch.
            yield cur_batch
|
|
|
|
| def _npz_paths_and_length(glob_path: str) -> Tuple[List[str], Optional[int]]: |
| |
| count_match = re.match("^(.*)\\[:([0-9]*)\\]$", glob_path) |
| if count_match: |
| raw_path = count_match[1] |
| max_count = int(count_match[2]) |
| else: |
| raw_path = glob_path |
| max_count = None |
| paths = sorted(glob.glob(raw_path)) |
| if not len(paths): |
| raise ValueError(f"no paths found matching: {glob_path}") |
| return paths, max_count |
|
|
|
|
class NpzArrayReader(ABC):
    """
    Abstract interface for pulling successive batches of rows out of a single
    array stored in an npz file.
    """

    @abstractmethod
    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        """
        Read up to batch_size rows, or return None once the array is exhausted.
        """
|
|
|
|
class StreamingNpzArrayReader(NpzArrayReader):
    """
    Reads an array incrementally from a file handle positioned at the start
    of the raw .npy data (the header must already have been consumed), never
    loading the whole array into memory.
    """

    def __init__(self, arr_f, shape, dtype):
        self.arr_f = arr_f
        self.shape = shape
        self.dtype = dtype
        self.idx = 0  # number of rows consumed so far

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        """Read up to batch_size rows; None once all rows were consumed."""
        remaining = self.shape[0] - self.idx
        if remaining <= 0:
            return None

        bs = batch_size if batch_size < remaining else remaining
        self.idx += bs

        # Zero-itemsize dtypes carry no bytes on disk; fabricate the result.
        if self.dtype.itemsize == 0:
            return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)

        elems_per_row = np.prod(self.shape[1:])
        num_bytes = int(bs * elems_per_row * self.dtype.itemsize)
        raw = _read_bytes(self.arr_f, num_bytes, "array data")
        return np.frombuffer(raw, dtype=self.dtype).reshape([bs, *self.shape[1:]])
|
|
|
|
class MemoryNpzArrayReader(NpzArrayReader):
    """
    Fallback reader that loads the entire array into memory up-front and
    serves batches by slicing it.
    """

    def __init__(self, arr):
        self.arr = arr
        self.idx = 0  # index of the next row to serve

    @classmethod
    def load(cls, path: str, arr_name: str):
        """Construct a reader holding the named array from the npz at path."""
        with open(path, "rb") as f:
            arr = np.load(f)[arr_name]
        return cls(arr)

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        """Read up to batch_size rows; None once all rows were consumed."""
        if self.idx >= self.arr.shape[0]:
            return None

        start = self.idx
        self.idx = start + batch_size
        return self.arr[start : start + batch_size]
|
|
|
|
@contextmanager
def open_npz_arrays(path: str, arr_names: Sequence[str]) -> Iterator[List[NpzArrayReader]]:
    """
    Open one reader per named array in the npz file at path, yielding a list
    of readers in the same order as arr_names.

    The function recurses, nesting one context manager per array, so every
    underlying zip-member handle stays open for the duration of the
    with-block and is closed afterwards in reverse order.
    """
    if not len(arr_names):
        yield []
        return
    arr_name = arr_names[0]
    with open_array(path, arr_name) as arr_f:
        # Parse the .npy header to decide between streaming and in-memory reads.
        version = np.lib.format.read_magic(arr_f)
        header = None
        if version == (1, 0):
            header = np.lib.format.read_array_header_1_0(arr_f)
        elif version == (2, 0):
            header = np.lib.format.read_array_header_2_0(arr_f)

        if header is None:
            # Unknown .npy format version: fall back to loading into memory.
            reader = MemoryNpzArrayReader.load(path, arr_name)
        else:
            shape, fortran, dtype = header
            if fortran or dtype.hasobject:
                # Fortran-ordered or object arrays cannot be streamed as raw
                # C-contiguous bytes, so load them fully instead.
                reader = MemoryNpzArrayReader.load(path, arr_name)
            else:
                reader = StreamingNpzArrayReader(arr_f, shape, dtype)

        # Recurse to open the remaining arrays inside this open context.
        with open_npz_arrays(path, arr_names[1:]) as next_readers:
            yield [reader] + next_readers
|
|
|
|
class CombinedReader:
    """
    Zips several per-array readers together so that each read produces a
    dict of aligned batches, one entry per key.
    """

    def __init__(self, keys: List[str], readers: List[NpzArrayReader]):
        self.keys = keys
        self.readers = readers

    def read_batch(self, batch_size: int) -> Optional[Dict[str, np.ndarray]]:
        """
        Read up to batch_size rows from every underlying reader.

        Returns None once all readers are exhausted; raises RuntimeError if
        the readers disagree about how many rows remain.
        """
        batches = [reader.read_batch(batch_size) for reader in self.readers]
        some_done = any(b is None for b in batches)
        every_done = all(b is None for b in batches)
        if some_done != every_done:
            raise RuntimeError("different keys had different numbers of elements")
        if some_done:
            return None
        if any(len(b) != len(batches[0]) for b in batches):
            raise RuntimeError("different keys had different numbers of elements")
        return dict(zip(self.keys, batches))
|
|
|
|
| def _read_bytes(fp, size, error_template="ran out of data"): |
| """ |
| Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886 |
| |
| Read from file-like object until size bytes are read. |
| Raises ValueError if not EOF is encountered before size bytes are read. |
| Non-blocking objects only supported if they derive from io objects. |
| Required as e.g. ZipExtFile in python 2.6 can return less data than |
| requested. |
| """ |
| data = bytes() |
| while True: |
| |
| |
| |
| try: |
| r = fp.read(size - len(data)) |
| data += r |
| if len(r) == 0 or len(data) == size: |
| break |
| except io.BlockingIOError: |
| pass |
| if len(data) != size: |
| msg = "EOF: reading %s, expected %d bytes got %d" |
| raise ValueError(msg % (error_template, size, len(data))) |
| else: |
| return data |
|
|
|
|
@contextmanager
def open_array(path: str, arr_name: str):
    """
    Yield a binary file handle positioned at the start of the raw .npy member
    arr_name inside the npz (zip) archive at path.

    :raises ValueError: if the archive has no member named "<arr_name>.npy".
    """
    member = f"{arr_name}.npy"
    with open(path, "rb") as f, zipfile.ZipFile(f, "r") as zip_f:
        if member not in zip_f.namelist():
            raise ValueError(f"missing {arr_name} in npz file")
        with zip_f.open(member, "r") as arr_f:
            yield arr_f
|
|
|
|
| def _dict_batch_size(objs: Dict[str, np.ndarray]) -> int: |
| return len(next(iter(objs.values()))) |
|
|