|
from .modules import * |
|
from .parameter import ( |
|
Parameter as Parameter, |
|
UninitializedParameter as UninitializedParameter, |
|
UninitializedBuffer as UninitializedBuffer, |
|
) |
|
from .parallel import DataParallel as DataParallel |
|
from . import init |
|
from . import functional |
|
from . import utils |
|
|
|
|
|
def factory_kwargs(kwargs):
    r"""
    Given kwargs, return a canonicalized dict of factory kwargs that can be passed directly
    to factory functions like torch.empty, or raise if unrecognized kwargs are present.

    This function makes it simple to write code like this::

        class MyModule(nn.Module):
            def __init__(self, **kwargs):
                factory_kwargs = torch.nn.factory_kwargs(kwargs)
                self.weight = Parameter(torch.empty(10, **factory_kwargs))

    Why should you use this function instead of just passing `kwargs` along directly?

    1. This function does error validation, so if there are unexpected kwargs we will
       immediately report an error, instead of deferring it to the factory call.
    2. This function supports a special `factory_kwargs` argument, which can be used to
       explicitly specify a kwarg to be used for factory functions, in the event one of the
       factory kwargs conflicts with an already existing argument in the signature (e.g.
       in the signature ``def f(dtype, **kwargs)``, you can specify ``dtype`` for factory
       functions, as distinct from the dtype argument, by saying
       ``f(dtype1, factory_kwargs={"dtype": dtype2})``).

    Args:
        kwargs: a mapping of keyword arguments, or ``None``. The only accepted keys are
            ``device``, ``dtype``, ``memory_format`` and ``factory_kwargs``. The value of
            ``factory_kwargs`` (if given and not ``None``) must be convertible with
            ``dict(...)``. Its entries are copied through without key validation.

    Returns:
        A new ``dict`` that merges the entries of ``factory_kwargs`` with any of
        ``device``/``dtype``/``memory_format`` given directly. ``None`` input yields ``{}``.
        Neither ``kwargs`` nor the nested ``factory_kwargs`` mapping is mutated.

    Raises:
        TypeError: if ``kwargs`` contains keys other than the accepted ones, or if one of
            ``device``/``dtype``/``memory_format`` is given both directly and inside
            ``factory_kwargs``.
    """
    if kwargs is None:
        return {}

    # A tuple rather than a set, so the duplicate check below runs in a fixed order.
    # str hashing is randomized per process, so iterating a set would make the reported
    # key vary from run to run when several keys are duplicated.
    simple_keys = ("device", "dtype", "memory_format")
    expected_keys = set(simple_keys) | {"factory_kwargs"}
    if not kwargs.keys() <= expected_keys:
        raise TypeError(f"unexpected kwargs {kwargs.keys() - expected_keys}")

    # Copy the explicit factory kwargs so that merging into `r` never touches the
    # caller's mapping. An explicit `factory_kwargs=None` is treated like an omitted one.
    explicit = kwargs.get("factory_kwargs")
    r = {} if explicit is None else dict(explicit)

    for k in simple_keys:
        if k in kwargs:
            if k in r:
                raise TypeError(f"{k} specified twice, in **kwargs and in factory_kwargs")
            r[k] = kwargs[k]

    return r
|
|