julien-c's picture
julien-c HF staff
vendor cog_sdxl
3ab8901 verified
raw
history blame
4.03 kB
import contextlib
import contextvars
import threading
from typing import (
Callable,
ContextManager,
NamedTuple,
Optional,
TypeVar,
Union,
)
import torch
__all__ = ["no_init_or_tensor"]
Model = TypeVar("Model")
def no_init_or_tensor(
loading_code: Optional[Callable[..., Model]] = None
) -> Union[Model, ContextManager]:
"""
Suppress the initialization of weights while loading a model.
Can either directly be passed a callable containing model-loading code,
which will be evaluated with weight initialization suppressed,
or used as a context manager around arbitrary model-loading code.
Args:
loading_code: Either a callable to evaluate
with model weight initialization suppressed,
or None (the default) to use as a context manager.
Returns:
The return value of `loading_code`, if `loading_code` is callable.
Otherwise, if `loading_code` is None, returns a context manager
to be used in a `with`-statement.
Examples:
As a context manager::
from transformers import AutoConfig, AutoModelForCausalLM
config = AutoConfig("EleutherAI/gpt-j-6B")
with no_init_or_tensor():
model = AutoModelForCausalLM.from_config(config)
Or, directly passing a callable::
from transformers import AutoConfig, AutoModelForCausalLM
config = AutoConfig("EleutherAI/gpt-j-6B")
model = no_init_or_tensor(lambda: AutoModelForCausalLM.from_config(config))
"""
if loading_code is None:
return _NoInitOrTensorImpl.context_manager()
elif callable(loading_code):
with _NoInitOrTensorImpl.context_manager():
return loading_code()
else:
raise TypeError(
"no_init_or_tensor() expected a callable to evaluate,"
" or None if being used as a context manager;"
f' got an object of type "{type(loading_code).__name__}" instead.'
)
class _NoInitOrTensorImpl:
# Implementation of the thread-safe, async-safe, re-entrant context manager
# version of no_init_or_tensor().
# This class essentially acts as a namespace.
# It is not instantiable, because modifications to torch functions
# inherently affect the global scope, and thus there is no worthwhile data
# to store in the class instance scope.
_MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm)
_MODULE_ORIGINALS = tuple((m, m.reset_parameters) for m in _MODULES)
_ORIGINAL_EMPTY = torch.empty
is_active = contextvars.ContextVar("_NoInitOrTensorImpl.is_active", default=False)
_count_active: int = 0
_count_active_lock = threading.Lock()
@classmethod
@contextlib.contextmanager
def context_manager(cls):
if cls.is_active.get():
yield
return
with cls._count_active_lock:
cls._count_active += 1
if cls._count_active == 1:
for mod in cls._MODULES:
mod.reset_parameters = cls._disable(mod.reset_parameters)
# When torch.empty is called, make it map to meta device by replacing
# the device in kwargs.
torch.empty = cls._ORIGINAL_EMPTY
reset_token = cls.is_active.set(True)
try:
yield
finally:
cls.is_active.reset(reset_token)
with cls._count_active_lock:
cls._count_active -= 1
if cls._count_active == 0:
torch.empty = cls._ORIGINAL_EMPTY
for mod, original in cls._MODULE_ORIGINALS:
mod.reset_parameters = original
@staticmethod
def _disable(func):
def wrapper(*args, **kwargs):
# Behaves as normal except in an active context
if not _NoInitOrTensorImpl.is_active.get():
return func(*args, **kwargs)
return wrapper
__init__ = None