import contextlib
import functools

import torch

from . import utils, wrap
__all__ = ['half_function', 'float_function',
'register_half_function', 'register_float_function',
'register_float_module', 'init', 'reset']
# Module-level state: the active handle installed by init(), user-registered
# cast targets, module classes pinned to fp32, and the original Module.half
# saved so reset() can restore it.
_DECORATOR_HANDLE = None
_USER_CAST_REGISTRY = set()
_USER_FLOAT_MODULE = set()
_ORIGINAL_MODULE_HALF = None
def _decorator_helper(orig_fn, cast_fn, wrap_fn):
    """Return a wrapper that applies cast_fn only while a handle is active."""
    def wrapper(*args, **kwargs):
        handle = _DECORATOR_HANDLE
        if handle is None or not handle.is_active():
            return orig_fn(*args, **kwargs)
        inner_cast_fn = utils.verbosify(cast_fn, orig_fn.__name__,
                                        handle.verbose)
        return wrap_fn(orig_fn, inner_cast_fn, handle)(*args, **kwargs)
    return wrapper
# Decorator form: annotate individual functions.
def half_function(fn):
    """Cast `fn`'s tensor inputs to fp16 (with caching) while amp is active."""
    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
    return _decorator_helper(fn, utils.maybe_half, wrap_fn)


def float_function(fn):
    """Cast `fn`'s tensor inputs to fp32 (no caching) while amp is active."""
    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)
    return _decorator_helper(fn, utils.maybe_float, wrap_fn)
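
# Example (illustrative sketch, not part of this module): decorating a
# user-defined op. `my_matmul` is a hypothetical function. Before init()
# is called (or after reset()) it runs unmodified; once a handle is
# active, its tensor inputs are cast to fp16 on the way in.
#
#   @half_function
#   def my_matmul(x, weight):
#       return x @ weight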
# Registry form: mark existing functions on a module or class by name; the
# actual wrappers are installed when init() runs.
def register_half_function(module, name):
    if not hasattr(module, name):
        raise ValueError('No function named {} in module {}.'.format(
            name, module))
    _USER_CAST_REGISTRY.add((module, name, utils.maybe_half))


def register_float_function(module, name):
    if not hasattr(module, name):
        raise ValueError('No function named {} in module {}.'.format(
            name, module))
    _USER_CAST_REGISTRY.add((module, name, utils.maybe_float))
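
# Example (illustrative): register torch.nn.functional ops by name before
# calling init(); init() installs the wrappers and clears the registry.
#
#   import torch
#   register_float_function(torch.nn.functional, 'softmax')
#   register_half_function(torch.nn.functional, 'linear')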
def register_float_module(module, cast_args=True):
    """Keep `module`'s own parameters and buffers in fp32 under .half().

    If cast_args is True, the class's forward is also registered so its
    inputs are cast to fp32 on entry.
    """
    if not issubclass(module, torch.nn.modules.module.Module):
        raise ValueError('{} is not a torch Module'.format(module))
    if cast_args:
        register_float_function(module, 'forward')
    _USER_FLOAT_MODULE.add(module)
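
# Example (illustrative, hypothetical class): keep a numerically sensitive
# user module's parameters and buffers in fp32 and up-cast its inputs on
# entry. Note that its child modules, if any, are still converted by .half().
#
#   class AttnPool(torch.nn.Module):  # hypothetical
#       ...
#
#   register_float_module(AttnPool, cast_args=True)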
class AmpHandle(object):
    """Runtime state for amp: cast cache, active flag, verbosity, and the
    wrappers installed so far (saved so reset() can undo them)."""

    def __init__(self, enable_caching=True, verbose=False):
        self._enable_caching = enable_caching
        self._verbose = verbose
        self._cache = dict()
        self._is_active = True
        self._all_wrappers = []
def is_active(self):
return self._is_active
    @contextlib.contextmanager
    def _disable_casts(self):
        # try/finally so casting is re-enabled even if the body raises.
        self._is_active = False
        try:
            yield
        finally:
            self._is_active = True
def _clear_cache(self):
self._cache.clear()
    # Experimental support for saving / restoring the original (uncast)
    # versions of wrapped functions.
def _save_func(self, mod, fn, func):
self._all_wrappers.append((mod, fn, func))
def _deactivate(self):
for mod, fn, func in self._all_wrappers:
utils.set_func(mod, fn, func)
self._all_wrappers = []
@property
def has_cache(self):
return self._enable_caching
@property
def cache(self):
return self._cache
def remove_cache(self, param):
if self.has_cache and param in self.cache:
del self.cache[param]
@property
def verbose(self):
return self._verbose
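
# Example (illustrative): temporarily run a region without any casts,
# e.g. for a numerically sensitive reduction. _disable_casts is private,
# but it is the mechanism the handle provides for this.
#
#   handle = init()
#   with handle._disable_casts():
#       loss = logits.float().sum()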
def _half_helper(verbose=False):
    def _half_wrapper(self):
        # Recurse into children first; each child invokes this same wrapper.
        for module in self.children():
            module.half()
        # Classes registered via register_float_module keep their own
        # parameters and buffers in fp32.
        if self.__class__ in _USER_FLOAT_MODULE:
            if verbose:
                print('Skip half convert for {}'.format(self.__class__))
            return self
fn = lambda t: t.half() if t.is_floating_point() else t
for param in self._parameters.values():
if param is not None:
# Tensors stored in modules are graph leaves, and we don't
# want to create copy nodes, so we have to unpack the data.
param.data = fn(param.data)
if param._grad is not None:
param._grad.data = fn(param._grad.data)
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)
return self
return _half_wrapper
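
# Example (illustrative): with BatchNorm2d registered as a float module and
# init() called, .half() converts other layers but leaves the registered
# class's own parameters in fp32.
#
#   register_float_module(torch.nn.BatchNorm2d, cast_args=False)
#   init()
#   net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3),
#                             torch.nn.BatchNorm2d(8)).half()
#   # net[0].weight.dtype == torch.float16
#   # net[1].weight.dtype == torch.float32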
def init(enable_caching=True, verbose=False):
    global _DECORATOR_HANDLE
    global _ORIGINAL_MODULE_HALF
    handle = AmpHandle(enable_caching, verbose)
    if len(_USER_FLOAT_MODULE) > 0:
        # Save the original Module.half only once, so a repeated init() call
        # does not overwrite it with the wrapper installed below.
        if _ORIGINAL_MODULE_HALF is None:
            _ORIGINAL_MODULE_HALF = torch.nn.modules.module.Module.half
        utils.set_func(torch.nn.modules.module.Module, 'half',
                       _half_helper(verbose))
    # Force-{fp16, fp32} casts for user-registered functions
    for mod, fn, cast_fn in _USER_CAST_REGISTRY:
        try_caching = (cast_fn is utils.maybe_half)
        wrap.cached_cast(mod, fn, cast_fn, handle,
                         try_caching, verbose)
    _USER_CAST_REGISTRY.clear()
    _DECORATOR_HANDLE = handle
    return handle
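
# Example (illustrative): registration must precede init(), since init()
# consumes and clears the registry.
#
#   register_float_function(torch.nn.functional, 'softmax')
#   handle = init(enable_caching=True, verbose=False)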
def _clear_cache():
handle = _DECORATOR_HANDLE
if handle is None or not handle.is_active():
return
handle._clear_cache()
def reset():
    global _ORIGINAL_MODULE_HALF
    handle = _DECORATOR_HANDLE
    if handle is None or not handle.is_active():
        return
    handle._deactivate()
    # Restore Module.half only if init() actually replaced it; otherwise
    # Module.half would be clobbered with None.
    if _ORIGINAL_MODULE_HALF is not None:
        utils.set_func(torch.nn.modules.module.Module, 'half',
                       _ORIGINAL_MODULE_HALF)
        _ORIGINAL_MODULE_HALF = None
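
# Example (illustrative): tear down amp between experiments. reset()
# removes the installed wrappers and restores the original Module.half,
# after which decorated functions run unmodified again.
#
#   handle = init()
#   ...  # training loop
#   reset()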