# Copyright (c) Facebook, Inc. and its affiliates.

import io
import numpy as np
import os
from dataclasses import dataclass
from functools import reduce
from operator import mul
from typing import BinaryIO, Dict, Optional, Tuple
import torch

from detectron2.utils.comm import gather, get_rank
from detectron2.utils.file_io import PathManager


@dataclass
class SizeData:
    """
    Size specification of a single tensor field.

    Attributes:
        dtype (str): numpy-compatible data type name, e.g. "float32"
        shape (tuple of int): tensor shape, e.g. (16, 112, 112); an empty
            tuple denotes a scalar (0-dim) tensor
    """

    dtype: str
    shape: Tuple[int, ...]


def _calculate_num_elements(shape: Tuple[int, ...]) -> int:
    """
    Compute the number of elements in a tensor of the given shape.

    Returns 1 for an empty shape (scalar tensor).
    """
    # the initial value 1 makes `reduce` well-defined for empty shapes;
    # without it `reduce` raises TypeError on an empty sequence
    return reduce(mul, shape, 1)


def _calculate_record_field_size_b(data_schema: Dict[str, SizeData], field_name: str) -> int:
    """
    Compute the size in bytes occupied by field `field_name` in one record.
    """
    schema = data_schema[field_name]
    element_size_b = np.dtype(schema.dtype).itemsize
    return _calculate_num_elements(schema.shape) * element_size_b


def _calculate_record_size_b(data_schema: Dict[str, SizeData]) -> int:
    """
    Compute the total size in bytes of one record (sum of all field sizes).
    """
    return sum(
        _calculate_record_field_size_b(data_schema, field_name) for field_name in data_schema
    )


def _calculate_record_field_sizes_b(data_schema: Dict[str, SizeData]) -> Dict[str, int]:
    """
    Map every field name of the schema to its size in bytes within one record.
    """
    return {
        field_name: _calculate_record_field_size_b(data_schema, field_name)
        for field_name in data_schema
    }


class SingleProcessTensorStorage:
    """
    Compact tensor storage to keep tensor data of predefined size and type.

    Records are stored back to back, each occupying exactly `record_size_b`
    bytes; inside a record, fields are laid out in sorted field-name order.
    A record ID is therefore the index of the record in the storage.
    """

    def __init__(self, data_schema: Dict[str, SizeData], storage_impl: BinaryIO):
        """
        Construct tensor storage based on information on data shape and size.
        Internally uses numpy to interpret the type specification.
        The storage must support operations `seek(offset, whence=os.SEEK_SET)`,
        `tell()` and `read(size)` to be able to perform the `get` operation.
        The storage must support operation `write(bytes)` to be able to perform
        the `put` operation.

        Args:
            data_schema (dict: str -> SizeData): dictionary which maps tensor name
                to its size data (shape and data type), e.g.
                ```
                {
                  "coarse_segm": SizeData(dtype="float32", shape=(112, 112)),
                  "embedding": SizeData(dtype="float32", shape=(16, 112, 112)),
                }
                ```
            storage_impl (BinaryIO): io instance that handles file-like seek, tell,
                read and write operations, e.g. a file handle or a memory buffer
                like io.BytesIO
        """
        self.data_schema = data_schema
        self.record_size_b = _calculate_record_size_b(data_schema)
        self.record_field_sizes_b = _calculate_record_field_sizes_b(data_schema)
        self.storage_impl = storage_impl
        # ID to be assigned to the next record written by `put`
        self.next_record_id = 0

    def get(self, record_id: int) -> Dict[str, torch.Tensor]:
        """
        Load tensors from the storage by record ID.

        The stream position of the underlying storage is restored after reading,
        so `get` calls can be freely interleaved with `put` calls.

        Args:
            record_id (int): Record ID, for which to load the data

        Return:
            dict: str -> tensor: tensor name mapped to tensor data, recorded under the provided ID
        """
        storage = self.storage_impl
        # `put` writes at the current stream position; remember it so that
        # reading here does not make a later `put` overwrite existing records
        prev_pos = storage.tell()
        storage.seek(record_id * self.record_size_b, os.SEEK_SET)
        try:
            data_bytes = storage.read(self.record_size_b)
        finally:
            storage.seek(prev_pos, os.SEEK_SET)
        assert len(data_bytes) == self.record_size_b, (
            f"Expected data size {self.record_size_b} B could not be read: "
            f"got {len(data_bytes)} B"
        )
        # zero-copy view over the record bytes; each field is copied exactly once below
        data_view = memoryview(data_bytes)
        record = {}
        cur_idx = 0
        # it's important to read and write in the same order
        for field_name in sorted(self.data_schema):
            schema = self.data_schema[field_name]
            field_size_b = self.record_field_sizes_b[field_name]
            chunk = data_view[cur_idx : cur_idx + field_size_b]
            # `np.frombuffer` over immutable bytes gives a read-only array, which
            # torch does not support writing to; copying gives each tensor its own
            # writable, properly aligned memory
            data_np = (
                np.frombuffer(
                    chunk, dtype=schema.dtype, count=_calculate_num_elements(schema.shape)
                )
                .reshape(schema.shape)
                .copy()
            )
            record[field_name] = torch.from_numpy(data_np)
            cur_idx += field_size_b
        return record

    def put(self, data: Dict[str, torch.Tensor]) -> int:
        """
        Store tensors in the storage.

        All fields are validated and serialized before anything is written, so a
        failed validation never leaves a partial record in the storage.

        Args:
            data (dict: str -> tensor): data to store, a dictionary which maps
                tensor names into tensors; tensor shapes must match those specified
                in data schema.
        Return:
            int: record ID, under which the data is stored
        """
        chunks = []
        # it's important to read and write in the same order
        for field_name in sorted(self.data_schema):
            assert (
                field_name in data
            ), f"Field '{field_name}' not present in data: data keys are {data.keys()}"
            value = data[field_name]
            assert value.shape == self.data_schema[field_name].shape, (
                f"Mismatched tensor shapes for field '{field_name}': "
                f"expected {self.data_schema[field_name].shape}, got {value.shape}"
            )
            # `detach` allows storing tensors that require grad
            data_bytes = value.detach().cpu().numpy().tobytes()
            assert len(data_bytes) == self.record_field_sizes_b[field_name], (
                f"Expected field {field_name} to be of size "
                f"{self.record_field_sizes_b[field_name]} B, got {len(data_bytes)} B"
            )
            chunks.append(data_bytes)
        # single write of the complete record after all fields were validated
        self.storage_impl.write(b"".join(chunks))
        record_id = self.next_record_id
        self.next_record_id += 1
        return record_id


class SingleProcessFileTensorStorage(SingleProcessTensorStorage):
    """
    Implementation of a single process tensor storage which stores data in a file
    """

    def __init__(self, data_schema: Dict[str, SizeData], fpath: str, mode: str):
        """
        Args:
            data_schema (dict: str -> SizeData): tensor name to size data mapping
            fpath (str): path of the storage file
            mode (str): binary file mode; modes containing "w" open the file for
                writing through PathManager, modes containing "r" open a local
                copy of the file for reading
        """
        self.fpath = fpath
        assert "b" in mode, f"Tensor storage should be opened in binary mode, got '{mode}'"
        if "w" in mode:
            # pyre-fixme[6]: For 2nd argument expected `Union[typing_extensions.Liter...
            file_h = PathManager.open(fpath, mode)
        elif "r" in mode:
            local_fpath = PathManager.get_local_path(fpath)
            file_h = open(local_fpath, mode)
        else:
            raise ValueError(f"Unsupported file mode {mode}, supported modes: rb, wb")
        super().__init__(data_schema, file_h)  # pyre-ignore[6]


class SingleProcessRamTensorStorage(SingleProcessTensorStorage):
    """
    Implementation of a single process tensor storage which stores data in RAM
    """

    def __init__(self, data_schema: Dict[str, SizeData], buf: io.BytesIO):
        super().__init__(data_schema, buf)


class MultiProcessTensorStorage:
    """
    Representation of a set of tensor storages created by individual processes,
    allows to access those storages from a single owner process. The storages
    should either be shared or broadcasted to the owner process.
    The processes are identified by their rank, data is uniquely defined by
    the rank of the process and the record ID.
    """

    def __init__(self, rank_to_storage: Dict[int, SingleProcessTensorStorage]):
        self.rank_to_storage = rank_to_storage

    def get(self, rank: int, record_id: int) -> Dict[str, torch.Tensor]:
        """
        Load the record `record_id` from the storage created by process `rank`.
        """
        storage = self.rank_to_storage[rank]
        return storage.get(record_id)

    def put(self, rank: int, data: Dict[str, torch.Tensor]) -> int:
        """
        Store `data` in the storage of process `rank`; returns the record ID.
        """
        storage = self.rank_to_storage[rank]
        return storage.put(data)


class MultiProcessFileTensorStorage(MultiProcessTensorStorage):
    """
    Multi-process storage backed by one file per process rank.
    """

    def __init__(self, data_schema: Dict[str, SizeData], rank_to_fpath: Dict[int, str], mode: str):
        rank_to_storage = {
            rank: SingleProcessFileTensorStorage(data_schema, fpath, mode)
            for rank, fpath in rank_to_fpath.items()
        }
        super().__init__(rank_to_storage)  # pyre-ignore[6]


class MultiProcessRamTensorStorage(MultiProcessTensorStorage):
    """
    Multi-process storage backed by one in-memory buffer per process rank.
    """

    def __init__(self, data_schema: Dict[str, SizeData], rank_to_buffer: Dict[int, io.BytesIO]):
        rank_to_storage = {
            rank: SingleProcessRamTensorStorage(data_schema, buf)
            for rank, buf in rank_to_buffer.items()
        }
        super().__init__(rank_to_storage)  # pyre-ignore[6]


def _ram_storage_gather(
    storage: SingleProcessRamTensorStorage, dst_rank: int = 0
) -> Optional[MultiProcessRamTensorStorage]:
    """
    Gather the full contents of every process's RAM storage on `dst_rank`.

    Returns a multi-process storage on `dst_rank` and None on all other ranks.
    """
    storage.storage_impl.seek(0, os.SEEK_SET)
    # TODO: overhead, pickling a bytes object, can just pass bytes in a tensor directly
    # see detectron2/utils.comm.py
    data_list = gather(storage.storage_impl.read(), dst=dst_rank)
    if get_rank() != dst_rank:
        return None
    rank_to_buffer = {rank: io.BytesIO(data) for rank, data in enumerate(data_list)}
    multiprocess_storage = MultiProcessRamTensorStorage(storage.data_schema, rank_to_buffer)
    return multiprocess_storage


def _file_storage_gather(
    storage: SingleProcessFileTensorStorage,
    dst_rank: int = 0,
    mode: str = "rb",
) -> Optional[MultiProcessFileTensorStorage]:
    """
    Close the local storage file and gather all file paths on `dst_rank`,
    which reopens them with `mode`.

    Returns a multi-process storage on `dst_rank` and None on all other ranks.
    """
    # close first so that buffered data is flushed before the owner reads the file
    storage.storage_impl.close()
    fpath_list = gather(storage.fpath, dst=dst_rank)
    if get_rank() != dst_rank:
        return None
    rank_to_fpath = dict(enumerate(fpath_list))
    return MultiProcessFileTensorStorage(storage.data_schema, rank_to_fpath, mode)


def storage_gather(
    storage: SingleProcessTensorStorage, dst_rank: int = 0
) -> Optional[MultiProcessTensorStorage]:
    """
    Gather single-process storages from all ranks on `dst_rank`.

    Returns the combined storage on `dst_rank` and None on other ranks.

    Raises:
        TypeError: if the storage type is not supported
    """
    if isinstance(storage, SingleProcessRamTensorStorage):
        return _ram_storage_gather(storage, dst_rank)
    elif isinstance(storage, SingleProcessFileTensorStorage):
        return _file_storage_gather(storage, dst_rank)
    raise TypeError(f"Unsupported storage for gather operation: {storage}")