Update meta_init_context.py
meta_init_context.py CHANGED (+17 -12)
@@ -1,4 +1,5 @@
 from contextlib import contextmanager
+from typing import Any, Callable, Optional
 import torch
 import torch.nn as nn
 
@@ -57,25 +58,29 @@ def init_on_device(device: torch.device, include_buffers: bool=False):
     if include_buffers:
         old_register_buffer = nn.Module.register_buffer
 
-    def register_empty_parameter(module, name, param):
-        old_register_parameter(module, name, param)
+    def register_empty_parameter(self: torch.nn.Module, name: str, param: Optional[torch.nn.Parameter]):
+        old_register_parameter(self, name, param)
         if param is not None:
-            param_cls = type(module._parameters[name])
-            kwargs = module._parameters[name].__dict__
-            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
-
-    def register_empty_buffer(module, name, buffer):
-        old_register_buffer(module, name, buffer)
-        if buffer is not None:
-            module._buffers[name] = module._buffers[name].to(device)
+            parameter = self._parameters[name]
+            assert parameter is not None
+            param_cls = type(parameter)
+            kwargs = parameter.__dict__
+            self._parameters[name] = param_cls(parameter.to(device), **kwargs)
+
+    def register_empty_buffer(self: torch.nn.Module, name: str, tensor: Optional[torch.Tensor], persistent: bool=True):
+        old_register_buffer(self, name, tensor, persistent=persistent)
+        if tensor is not None:
+            named_buffer = self._buffers[name]
+            assert named_buffer is not None
+            self._buffers[name] = named_buffer.to(device)
     if include_buffers:
         tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']}
     else:
         tensor_constructors_to_patch = {}
 
-    def patch_tensor_constructor(fn):
+    def patch_tensor_constructor(fn: Callable):
 
-        def wrapper(*args, **kwargs):
+        def wrapper(*args: Any, **kwargs: Any):
             kwargs['device'] = device
             return fn(*args, **kwargs)
         return wrapper
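For context, below is a minimal usage sketch of the `init_on_device` context manager this diff modifies. `TinyModel` and the import path are illustrative assumptions, not part of the commit; the pattern shown (constructing a module under the context so parameters, buffers, and patched tensor constructors all land on a target device) is what the hooks in the diff implement.

import torch
import torch.nn as nn

# Assumes meta_init_context.py is importable from the working directory.
from meta_init_context import init_on_device

# Hypothetical toy module, purely for illustration.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.register_buffer('scale', torch.ones(4))

# Inside the context, the patched register_parameter/register_buffer hooks
# move every new parameter and buffer to `device`, and (with
# include_buffers=True) torch.empty/zeros/ones/full are patched to allocate
# there directly.
with init_on_device(torch.device('meta'), include_buffers=True):
    model = TinyModel()

print(model.linear.weight.device)  # meta
print(model.scale.device)          # meta

On the `meta` device, tensors carry only shape and dtype metadata, so this lets very large models be instantiated without allocating real storage, before real weights are loaded in.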