Spaces:
Running
on
L4
Running
on
L4
# -*- coding: utf-8 -*- | |
# File : replicate.py | |
# Author : Jiayuan Mao | |
# Email : maojiayuan@gmail.com | |
# Date : 27/01/2018 | |
# | |
# This file is part of Synchronized-BatchNorm-PyTorch. | |
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch | |
# Distributed under MIT License. | |
import functools

from torch.nn.parallel.data_parallel import DataParallel
# Names exported by `from <this module> import *`.
__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback',
]
class CallbackContext(object):
    """Plain attribute container shared by all replicas of one sub-module.

    `execute_replication_callbacks` creates one instance per sub-module position
    and passes that same instance to the `__data_parallel_replicate__` hook of
    every replica of that sub-module. The copies can therefore exchange state by
    setting attributes on it.
    """


def execute_replication_callbacks(modules):
    """
    Execute the replication callback `__data_parallel_replicate__` on each module created by the original replication.

    The callback is invoked as `__data_parallel_replicate__(ctx, copy_id)`, where `copy_id` is the index of the
    replica in `modules` and `ctx` is a `CallbackContext`.

    All replicas are assumed to be isomorphic: their `.modules()` traversals yield sub-modules in the same order.
    The j-th sub-module of every replica therefore receives the same context object, which is shared among the
    copies of that sub-module on different devices. Through this context, different copies can share information.

    The callbacks on the master copy (the first copy) are guaranteed to run before the callbacks of any slave copy.

    Args:
        modules: sequence of replicated modules, master copy first. An empty sequence is a no-op.
    """
    if not modules:
        # Nothing was replicated, so there is nobody to notify.
        return

    master_copy = modules[0]
    # Count the sub-modules without materialising the whole traversal.
    nr_modules = sum(1 for _ in master_copy.modules())
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    # Replicas are visited in order, so copy 0 (the master) is fully processed
    # before any slave copy. The ordering guarantee above relies on this.
    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)
class DataParallelWithCallback(DataParallel):
    """
    `DataParallel` that fires replication callbacks.

    After the stock `replicate` has produced the per-device copies, `execute_replication_callbacks` is run on them.
    It invokes `__data_parallel_replicate__(ctx, copy_id)` on every sub-module that defines that hook.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        replicas = super().replicate(module, device_ids)
        # Let each replica's sub-modules run their replication hook before use.
        execute_replication_callbacks(replicas)
        return replicas
def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object in place so that it fires replication callbacks.

    This is useful when you have a customized `DataParallel` implementation. The instance's `replicate` method is
    replaced by a wrapper that calls the previous `replicate` and then runs `execute_replication_callbacks` on the
    replicas it returned.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

    Args:
        data_parallel: the `DataParallel` instance to patch.

    Raises:
        AssertionError: if `data_parallel` is not a `DataParallel` instance (unless assertions are disabled).
    """
    assert isinstance(data_parallel, DataParallel), \
        'patch_replication_callback expects a DataParallel instance, got {}'.format(type(data_parallel).__name__)

    old_replicate = data_parallel.replicate

    # functools.wraps copies the name and docstring of the original bound method onto
    # the wrapper and exposes it as `__wrapped__`. The patched attribute therefore
    # still introspects as `replicate`.
    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    # Stored on the instance, so it shadows the class-level `replicate` method.
    data_parallel.replicate = new_replicate