| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from __future__ import absolute_import |
| | from __future__ import division |
| | from __future__ import print_function |
| |
|
| | import paddle |
| | import paddle.nn as nn |
| |
|
| | from typing import List |
| |
|
| |
|
def get_bn_running_state_names(
    model: nn.Layer,
    *,
    bn_types=(nn.BatchNorm2D, nn.SyncBatchNorm),
) -> List[str]:
    """Return the full state names of every batch-norm layer's running statistics.

    Walks ``model.named_sublayers()``. For each sublayer that is an instance
    of ``bn_types``, it appends ``'<sublayer name>._mean'`` followed by
    ``'<sublayer name>._variance'``.

    Args:
        model: The layer whose sublayers are inspected.
        bn_types: A class or tuple of classes (anything ``isinstance``
            accepts) to treat as batch-norm layers. The default is
            ``(nn.BatchNorm2D, nn.SyncBatchNorm)``, which reproduces the
            previous hard-coded behaviour. Any other class passed here must
            expose ``_mean`` and ``_variance`` attributes.

    Returns:
        A list with two names per matching sublayer (mean first, then
        variance), in the order ``named_sublayers()`` yields the sublayers.

    Raises:
        AttributeError: If a matching sublayer lacks ``_mean`` or
            ``_variance``.
    """
    names = []
    for name, layer in model.named_sublayers():
        if not isinstance(layer, bn_types):
            continue
        # Explicit checks rather than ``assert`` so validation still runs
        # under ``python -O``; otherwise names of missing states would be
        # returned silently.
        for attr in ('_mean', '_variance'):
            if not hasattr(layer, attr):
                raise AttributeError(f'{layer} has no attribute {attr!r}')
        names.extend([f'{name}._mean', f'{name}._variance'])
    return names
| |
|