efederici committed
Commit 6b4ce2c
1 Parent(s): e86fec8

Update param_init_fns.py

Files changed (1):
1. param_init_fns.py  +49  -51
param_init_fns.py CHANGED
@@ -2,22 +2,26 @@ import math
 import warnings
 from collections.abc import Sequence
 from functools import partial
-from typing import Optional, Tuple, Union
+from typing import Any, Callable, Optional, Tuple, Union
 import torch
 from torch import nn
+from .fc import FC_CLASS_REGISTRY
 from .norm import NORM_CLASS_REGISTRY
+try:
+    import transformer_engine.pytorch as te
+except:
+    te = None
 
-def torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs):
+def torch_default_param_init_fn_(module: nn.Module, **kwargs: Any) -> None:
     del kwargs
-    if verbose > 1:
-        warnings.warn(f"Initializing network using module's reset_parameters attribute")
-    if hasattr(module, 'reset_parameters'):
+    if hasattr(module, 'reset_parameters') and isinstance(module.reset_parameters, Callable):
         module.reset_parameters()
 
-def fused_init_helper_(module: nn.Module, init_fn_):
+def fused_init_helper_(module: nn.Module, init_fn_: Callable) -> None:
     _fused = getattr(module, '_fused', None)
     if _fused is None:
         raise RuntimeError(f'Internal logic error')
+    assert isinstance(module.weight, torch.Tensor)
     (dim, splits) = _fused
     splits = (0, *splits, module.weight.size(dim))
     for (s, e) in zip(splits[:-1], splits[1:]):
@@ -25,10 +29,8 @@ def fused_init_helper_(module: nn.Module, init_fn_):
         slice_indices[dim] = slice(s, e)
         init_fn_(module.weight[slice_indices])
 
-def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
+def generic_param_init_fn_(module: nn.Module, init_fn_: Callable, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
     del kwargs
-    if verbose > 1:
-        warnings.warn(f'If model has bias parameters they are initialized to 0.')
     init_div_is_residual = init_div_is_residual
     if init_div_is_residual is False:
         div_is_residual = 1.0
@@ -36,20 +38,18 @@ def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model:
         div_is_residual = math.sqrt(2 * n_layers)
     elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int):
         div_is_residual = init_div_is_residual
-    elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
+    elif init_div_is_residual.isnumeric():
         div_is_residual = float(init_div_is_residual)
     else:
         div_is_residual = 1.0
         raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}')
-    if init_div_is_residual is not False:
-        if verbose > 1:
-            warnings.warn(f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. ' + f'Set `init_div_is_residual: false` in init config to disable this.')
-    if isinstance(module, nn.Linear):
+    if isinstance(module, tuple(set(FC_CLASS_REGISTRY.values()))):
         if hasattr(module, '_fused'):
             fused_init_helper_(module, init_fn_)
         else:
             init_fn_(module.weight)
         if module.bias is not None:
+            assert isinstance(module.bias, torch.Tensor)
             torch.nn.init.zeros_(module.bias)
         if init_div_is_residual is not False and getattr(module, '_is_residual', False):
             with torch.no_grad():
@@ -60,8 +60,6 @@ def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model:
             if std == 0:
                 warnings.warn(f'Embedding layer initialized to 0.')
             emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
-            if verbose > 1:
-                warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.')
         elif emb_init_uniform_lim is not None:
             lim = emb_init_uniform_lim
             if isinstance(lim, Sequence):
@@ -75,17 +73,13 @@ def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model:
                 lim = [-lim, lim]
             (a, b) = lim
             emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
-            if verbose > 1:
-                warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.')
         else:
             emb_init_fn_ = init_fn_
         emb_init_fn_(module.weight)
     elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
-        if verbose > 1:
-            warnings.warn(f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.')
-        if hasattr(module, 'weight') and module.weight is not None:
+        if hasattr(module, 'weight') and isinstance(module.weight, torch.Tensor):
             torch.nn.init.ones_(module.weight)
-        if hasattr(module, 'bias') and module.bias is not None:
+        if hasattr(module, 'bias') and isinstance(module.bias, torch.Tensor):
             torch.nn.init.zeros_(module.bias)
     elif isinstance(module, nn.MultiheadAttention):
         if module._qkv_same_embed_dim:
@@ -114,32 +108,45 @@ def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model:
                 module.out_proj.weight.div_(div_is_residual)
         if module.out_proj.bias is not None:
             torch.nn.init.zeros_(module.out_proj.bias)
+    elif te is not None and isinstance(module, te.LayerNormMLP):
+        if isinstance(module.layer_norm_weight, torch.Tensor):
+            torch.nn.init.ones_(module.layer_norm_weight)
+        if isinstance(module.layer_norm_bias, torch.Tensor):
+            torch.nn.init.zeros_(module.layer_norm_bias)
+        init_fn_(module.fc1_weight)
+        if module.fc1_bias is not None:
+            assert isinstance(module.fc1_bias, torch.Tensor)
+            torch.nn.init.zeros_(module.fc1_bias)
+        init_fn_(module.fc2_weight)
+        if module.fc2_bias is not None:
+            assert isinstance(module.fc2_bias, torch.Tensor)
+            torch.nn.init.zeros_(module.fc2_bias)
+        with torch.no_grad():
+            module.fc2_weight.div_(div_is_residual)
     else:
         for _ in module.parameters(recurse=False):
             raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.')
 
-def _normal_init_(std, mean=0.0):
+def _normal_init_(std: float, mean: float=0.0) -> Callable:
     return partial(torch.nn.init.normal_, mean=mean, std=std)
 
-def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
+def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
     del kwargs
     init_fn_ = _normal_init_(std=std)
-    if verbose > 1:
-        warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')
-    generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+    generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
-def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
+def baseline_param_init_fn_(module: nn.Module, init_std: Optional[float], n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
     del kwargs
     if init_std is None:
         raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
-    _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+    _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
-def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
+def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
     del kwargs
     std = math.sqrt(2 / (5 * d_model))
-    _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+    _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
-def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
+def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
     """From section 2.3.1 of GPT-NeoX-20B:
 
     An Open-Source Autoregressive Language Model — Black et. al. (2022)
@@ -148,34 +155,25 @@ def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init
     """
     del kwargs
     residual_div = n_layers / math.sqrt(10)
-    if verbose > 1:
-        warnings.warn(f'setting init_div_is_residual to {residual_div}')
-    small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+    small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
-def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
+def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', **kwargs: Any) -> None:
     del kwargs
-    if verbose > 1:
-        warnings.warn(f'Using nn.init.kaiming_uniform_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
     kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
-    generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+    generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
-def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
+def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', **kwargs: Any) -> None:
     del kwargs
-    if verbose > 1:
-        warnings.warn(f'Using nn.init.kaiming_normal_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
     kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
-    generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+    generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
-def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
+def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, **kwargs: Any) -> None:
     del kwargs
     xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
-    if verbose > 1:
-        warnings.warn(f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' + f'gain={init_gain}')
-    generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+    generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
-def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
+def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, **kwargs: Any) -> None:
+    del kwargs
     xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
-    if verbose > 1:
-        warnings.warn(f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' + f'gain={init_gain}')
-    generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+    generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_}
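
For reference, a minimal sketch of how an entry from MODEL_INIT_REGISTRY is typically consumed: nn.Module.apply walks the module tree and hands each submodule to the chosen init fn, with the remaining arguments bound via functools.partial. The toy model and the n_layers/d_model values below are illustrative assumptions, not part of this commit.

    from functools import partial

    import torch.nn as nn

    from .param_init_fns import MODEL_INIT_REGISTRY

    # Hypothetical toy model: generic_param_init_fn_ dispatches on submodule type
    # (fc/Linear classes, nn.Embedding, registered norms, nn.MultiheadAttention, ...).
    model = nn.Sequential(nn.Embedding(128, 64), nn.Linear(64, 64))

    # 'small_init_' draws weights from a normal with std = sqrt(2 / (5 * d_model));
    # n_layers only matters for submodules flagged `_is_residual`, whose weights
    # are divided by sqrt(2 * n_layers) after initialization.
    init_fn_ = partial(MODEL_INIT_REGISTRY['small_init_'], n_layers=2, d_model=64)

    # nn.Module.apply(fn) calls fn(submodule) on every submodule, then on model itself.
    model.apply(init_fn_)

In the surrounding MPT-style modeling code, a param_init_fn method usually binds these arguments from config.init_config (keyed by init_config['name']); the direct registry lookup above is only for illustration.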