efederici committed
Commit d25a04f
1 Parent(s): b1dbc68

Update meta_init_context.py

Files changed (1):
  meta_init_context.py +17 -12
meta_init_context.py CHANGED
@@ -1,4 +1,5 @@
 from contextlib import contextmanager
+from typing import Any, Callable, Optional
 import torch
 import torch.nn as nn
 
@@ -57,25 +58,29 @@ def init_on_device(device: torch.device, include_buffers: bool=False):
     if include_buffers:
         old_register_buffer = nn.Module.register_buffer
 
-    def register_empty_parameter(module, name, param):
-        old_register_parameter(module, name, param)
+    def register_empty_parameter(self: torch.nn.Module, name: str, param: Optional[torch.nn.Parameter]):
+        old_register_parameter(self, name, param)
         if param is not None:
-            param_cls = type(module._parameters[name])
-            kwargs = module._parameters[name].__dict__
-            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
-
-    def register_empty_buffer(module, name, buffer):
-        old_register_buffer(module, name, buffer)
-        if buffer is not None:
-            module._buffers[name] = module._buffers[name].to(device)
+            parameter = self._parameters[name]
+            assert parameter is not None
+            param_cls = type(parameter)
+            kwargs = parameter.__dict__
+            self._parameters[name] = param_cls(parameter.to(device), **kwargs)
+
+    def register_empty_buffer(self: torch.nn.Module, name: str, tensor: Optional[torch.Tensor], persistent: bool=True):
+        old_register_buffer(self, name, tensor, persistent=persistent)
+        if tensor is not None:
+            named_buffer = self._buffers[name]
+            assert named_buffer is not None
+            self._buffers[name] = named_buffer.to(device)
     if include_buffers:
         tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']}
     else:
         tensor_constructors_to_patch = {}
 
-    def patch_tensor_constructor(fn):
+    def patch_tensor_constructor(fn: Callable):
 
-        def wrapper(*args, **kwargs):
+        def wrapper(*args: Any, **kwargs: Any):
             kwargs['device'] = device
             return fn(*args, **kwargs)
         return wrapper
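
For reference, the `init_on_device` context manager patched here temporarily swaps out `nn.Module.register_parameter` (and, with `include_buffers=True`, `nn.Module.register_buffer` plus the `torch.empty`/`torch.zeros`/`torch.ones`/`torch.full` constructors) so that every parameter and buffer a module creates lands directly on the target device. A minimal usage sketch, assuming the file is importable as `meta_init_context` (the import path is an assumption, not part of this commit):

import torch
import torch.nn as nn
from meta_init_context import init_on_device  # hypothetical import path

# Build a module with all parameters on the meta device: no real storage
# is allocated, so the weights can be materialized later (e.g. from a
# checkpoint) instead of being initialized in RAM.
with init_on_device(torch.device('meta'), include_buffers=True):
    model = nn.Linear(4096, 4096)
    t = torch.ones(4096)  # patched constructor -> also on the meta device

print(model.weight.device)  # device(type='meta')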