# Description: This file contains the code for serializing the dataset.
# From https://github.com/ppwwyyxx/RAM-multiprocess-dataloader/blob/795868a37446d61412b9a58dbb1b7c76e75d39c4/serialize.py
# Copyright (c) Facebook, Inc. and its affiliates.
"""
List serialization code adopted from
https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/common.py
"""
import multiprocessing as mp
from typing import List, Any, Optional
import pickle
import numpy as np
import torch
import torch.distributed as dist
import functools
import os
from datetime import timedelta


def get_world_size() -> int:
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank() -> int:
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def get_local_rank() -> int:
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    # LOCAL_RANK is not guaranteed to be set; fall back to the SLURM variable.
    if 'LOCAL_RANK' in os.environ:
        return int(os.environ['LOCAL_RANK'])
    elif 'SLURM_LOCALID' in os.environ:
        return int(os.environ['SLURM_LOCALID'])
    else:
        raise RuntimeError("Unable to get local rank")


def get_local_size() -> int:
    return torch.cuda.device_count()


@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks.
    The result is cached.
    """
    if dist.get_backend() == "nccl":
        return dist.new_group(backend="gloo", timeout=timedelta(minutes=60))
    else:
        return dist.group.WORLD


def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: list of data gathered from each rank
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = (
            _get_global_gloo_group()
        )  # use CPU group by default, to reduce GPU RAM usage.
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return [data]
    output = [None for _ in range(world_size)]
    dist.all_gather_object(output, data, group=group)
    return output
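

# A minimal usage sketch (hedged; assumes torch.distributed has already been
# initialized, e.g. via torchrun and init_process_group, and the dict below is
# purely illustrative):
#
#   stats = {"rank": get_rank(), "num_samples": 1000}
#   gathered = all_gather(stats)              # one entry per rank
#   total = sum(s["num_samples"] for s in gathered)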


class NumpySerializedList:
    def __init__(self, lst: list):
        def _serialize(data):
            buffer = pickle.dumps(data, protocol=-1)
            return np.frombuffer(buffer, dtype=np.uint8)

        print(
            "Serializing {} elements to byte tensors and concatenating them all ...".format(
                len(lst)
            )
        )
        self._lst = [_serialize(x) for x in lst]
        self._addr = np.asarray([len(x) for x in self._lst], dtype=np.int64)
        self._addr = np.cumsum(self._addr)
        self._lst = np.concatenate(self._lst)
        print("Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024**2))

    def __len__(self):
        return len(self._addr)

    def __getitem__(self, idx):
        start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
        end_addr = self._addr[idx].item()
        bytes = memoryview(self._lst[start_addr:end_addr])
        return pickle.loads(bytes)


class TorchSerializedList(NumpySerializedList):
    def __init__(self, lst: list):
        super().__init__(lst)
        self._addr = torch.from_numpy(self._addr)
        self._lst = torch.from_numpy(self._lst)

    def __getitem__(self, idx):
        start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
        end_addr = self._addr[idx].item()
        bytes = memoryview(self._lst[start_addr:end_addr].numpy())
        return pickle.loads(bytes)
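

# Hedged usage sketch: wrap a dataset's list of picklable samples so that many
# small Python objects are replaced by two flat tensors, avoiding
# copy-on-write memory growth in forked DataLoader workers. The sample list
# below is illustrative and not part of the original code.
#
#   samples = [{"file_name": f"{i}.jpg", "label": i % 10} for i in range(100)]
#   serialized = TorchSerializedList(samples)
#   assert len(serialized) == 100 and serialized[3]["label"] == 3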


def local_scatter(array: Optional[List[Any]]):
    """
    Scatter an array from local leader to all local workers.
    The i-th local worker gets array[i].

    Args:
        array: Array with same size of #local workers.
    """
    if get_local_size() <= 1:
        # Just one worker. Do nothing.
        return array[0]
    if get_local_rank() == 0:
        assert len(array) == get_local_size()
        all_gather(array)
    else:
        all_data = all_gather(None)
        array = all_data[get_rank() - get_local_rank()]
    return array[get_local_rank()]
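

# Illustrative call pattern (hypothetical values, assuming two workers per
# node): the local leader supplies one element per local worker, the other
# ranks pass None, and each rank receives the element at its local rank.
#
#   payload = ["item-for-rank0", "item-for-rank1"] if get_local_rank() == 0 else None
#   mine = local_scatter(payload)   # local_rank 1 receives "item-for-rank1"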


# NOTE: https://github.com/facebookresearch/mobile-vision/pull/120
# has another implementation that does not use tensors.
class TorchShmSerializedList(TorchSerializedList):
    def __init__(self, lst: list):
        if get_local_rank() == 0:
            super().__init__(lst)

        if get_local_rank() == 0:
            # Move data to shared memory, obtain a handle to send to each local worker.
            # This is cheap because a tensor will only be moved to shared memory once.
            handles = [None] + [
                bytes(mp.reduction.ForkingPickler.dumps((self._addr, self._lst)))
                for _ in range(get_local_size() - 1)
            ]
        else:
            handles = None

        # Each worker receives the handle from local leader.
        handle = local_scatter(handles)
        if get_local_rank() > 0:
            # Materialize the tensor from shared memory.
            self._addr, self._lst = mp.reduction.ForkingPickler.loads(handle)

        print(
            f"Worker {get_rank()} obtains a dataset of length="
            f"{len(self)} from its local leader."
        )
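

# Hedged sketch of how a Dataset might use this class (MyDataset and the
# annotations argument are assumptions, not part of the original code). Every
# rank constructs the object, but only the local leader pays the serialization
# and memory cost; other local ranks attach to the leader's shared-memory
# tensors.
#
#   class MyDataset(torch.utils.data.Dataset):
#       def __init__(self, annotations: list):
#           self._data = TorchShmSerializedList(annotations)
#
#       def __len__(self):
#           return len(self._data)
#
#       def __getitem__(self, idx):
#           return self._data[idx]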


# From https://github.com/ppwwyyxx/RAM-multiprocess-dataloader/issues/5#issuecomment-1510676170
def local_broadcast_process_authkey():
    if int(os.environ['LOCAL_WORLD_SIZE']) == 1:
        return
    local_rank = int(os.environ['LOCAL_RANK'])
    authkey = bytes(mp.current_process().authkey)
    all_keys = all_gather(authkey)
    local_leader_key = all_keys[get_rank() - local_rank]
    if authkey != local_leader_key:
        print("Process authkey is different from the key of local leader. This might happen when "
              "workers are launched independently.")
        print("Overwriting local authkey ...")
        mp.current_process().authkey = local_leader_key
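

# Typical initialization order, as a hedged sketch (the torchrun launcher, the
# env:// init method, and MyDataset/load_annotations are assumptions):
#
#   dist.init_process_group(backend="nccl", init_method="env://")
#   local_broadcast_process_authkey()   # align authkeys with the local leader
#   dataset = MyDataset(load_annotations())  # uses TorchShmSerializedList inside
#
# Matching authkeys matter because unpickling the local leader's shared-memory
# handles connects back to the leader's process, which authenticates peers
# with its authkey.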