from multiprocessing import shared_memory
# import multiprocessing
# if hasattr(multiprocessing, "shared_memory"):
#     from multiprocessing import shared_memory
# else:
#     # workaround for single gpu inference on colab
#     shared_memory = None

import random
import pickle
import time
import copy

import torch
import torch.distributed as dist

from lib.cfg_holder import cfg_unique_holder as cfguh

def singleton(class_):
    instances = {}
    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]
    return getinstance

def is_ddp():
    return dist.is_available() and dist.is_initialized()

def get_rank(type='local'):
    ddp = is_ddp()
    global_rank = dist.get_rank() if ddp else 0
    local_world_size = torch.cuda.device_count()
    if type == 'global':
        return global_rank
    elif type == 'local':
        return global_rank % local_world_size
    elif type == 'node':
        return global_rank // local_world_size
    elif type == 'all':
        return global_rank, \
            global_rank % local_world_size, \
            global_rank // local_world_size
    else:
        assert False, 'Unknown type'

def get_world_size(type='local'):
    ddp = is_ddp()
    global_rank = dist.get_rank() if ddp else 0
    global_world_size = dist.get_world_size() if ddp else 1
    local_world_size = torch.cuda.device_count()
    if type == 'global':
        return global_world_size
    elif type == 'local':
        return local_world_size
    elif type == 'node':
        return global_world_size // local_world_size
    elif type == 'all':
        return global_world_size, local_world_size, \
            global_world_size // local_world_size
    else:
        assert False, 'Unknown type'

class barrier_lock(object):
    """
    A busy-wait barrier for all local ranks on one node, backed by a
    shared-memory byte array with one slot per rank.
    """
    def __init__(self, n):
        self.n = n
        id = int(random.random()*10000) + int(time.time())*10000
        self.lock_shmname = 'barrier_lock_{}'.format(id)
        lock_shm = shared_memory.SharedMemory(
            name=self.lock_shmname, create=True, size=n)
        for i in range(n):
            lock_shm.buf[i] = 0
        lock_shm.close()

    def destroy(self):
        try:
            lock_shm = shared_memory.SharedMemory(
                name=self.lock_shmname)
            lock_shm.close()
            lock_shm.unlink()
        except:
            # shared memory already released
            return

    def wait(self, k):
        lock_shm = shared_memory.SharedMemory(
            name=self.lock_shmname)
        assert lock_shm.buf[k] == 0, 'Two waits on the same id are not allowed.'
        lock_shm.buf[k] = 1
        if k == 0:
            # rank 0 waits until every rank has arrived, then releases them all
            while any(lock_shm.buf[i] == 0 for i in range(self.n)):
                pass
            for i in range(self.n):
                lock_shm.buf[i] = 0
            return
        else:
            # other ranks spin until rank 0 clears their slot
            while lock_shm.buf[k] != 0:
                pass

class nodewise_sync_global(object):
    """
    The global part of nodewise_sync that needs to be created in the master
    process before spawn.
    """
    def __init__(self):
        self.local_world_size = get_world_size('local')
        self.b_lock = barrier_lock(self.local_world_size)
        id = int(random.random()*10000) + int(time.time())*10000
        self.id_shmname = 'nodewise_sync_id_shm_{}'.format(id)

    def destroy(self):
        self.b_lock.destroy()
        try:
            shm = shared_memory.SharedMemory(name=self.id_shmname)
            shm.close()
            shm.unlink()
        except:
            # shared memory already released
            return

@singleton
class nodewise_sync(object):
    """
    A class that centralizes node-wise sync activities.
    The backend is multiprocessing shared memory rather than torch, because
    torch does not support this.
    """
    def __init__(self):
        pass

    def copy_global(self, reference):
        self.local_world_size = reference.local_world_size
        self.b_lock = reference.b_lock
        self.id_shmname = reference.id_shmname
        return self

    def local_init(self):
        self.ddp = is_ddp()
        self.global_rank, self.local_rank, self.node_rank = get_rank('all')
        self.global_world_size, self.local_world_size, self.nodes = get_world_size('all')
        if self.local_rank == 0:
            # reserve a shared-memory block large enough to hold a pickled sync id
            temp = int(random.random()*10000) + int(time.time())*10000
            temp = pickle.dumps(temp)
            shm = shared_memory.SharedMemory(
                name=self.id_shmname, create=True, size=len(temp))
            shm.close()
        return self

    def random_sync_id(self):
        assert self.local_rank is not None, 'Not initialized!'
        if self.local_rank == 0:
            # rank 0 generates the id and publishes it through shared memory
            sync_id = int(random.random()*10000) + int(time.time())*10000
            data = pickle.dumps(sync_id)
            shm = shared_memory.SharedMemory(name=self.id_shmname)
            shm.buf[0:len(data)] = data[0:len(data)]
            self.barrier()
            shm.close()
        else:
            # other ranks wait for rank 0 to write, then read the id back
            self.barrier()
            shm = shared_memory.SharedMemory(name=self.id_shmname)
            sync_id = pickle.loads(shm.buf)
            shm.close()
        return sync_id

    def barrier(self):
        self.b_lock.wait(self.local_rank)

    def broadcast_r0(self, data=None):
        assert self.local_rank is not None, 'Not initialized!'
        id = self.random_sync_id()
        shmname = 'broadcast_r0_{}'.format(id)
        if self.local_rank == 0:
            assert data is not None, 'Rank 0 needs to provide data!'
            data = pickle.dumps(data)
            datan = len(data)
            load_info_shm = shared_memory.SharedMemory(
                name=shmname, create=True, size=datan)
            load_info_shm.buf[0:datan] = data[0:datan]
            self.barrier()  # release the other ranks to read
            self.barrier()  # wait until every rank has finished reading
            load_info_shm.close()
            load_info_shm.unlink()
            return None
        else:
            assert data is None, 'Ranks other than 0 should pass None as data!'
            self.barrier()  # wait for rank 0 to finish writing
            shm = shared_memory.SharedMemory(name=shmname)
            data = pickle.loads(shm.buf)
            shm.close()
            self.barrier()  # signal rank 0 that reading is done
            return data

    def destroy(self):
        self.b_lock.destroy()
        try:
            shm = shared_memory.SharedMemory(name=self.id_shmname)
            shm.close()
            shm.unlink()
        except:
            # shared memory already released
            return

# import contextlib
# @contextlib.contextmanager
# def weight_sync(module, sync):
#     assert isinstance(module, torch.nn.Module)
#     if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
#         yield
#     else:
#         with module.no_sync():
#             yield

# def weight_sync(net):
#     for parameters in net.parameters():
#         dist.all_reduce(parameters, dist.ReduceOp.AVG)