# Spaces:
# Runtime error
# Runtime error
import torch.nn as nn | |
from torch._functorch.utils import exposed_in | |
def batch_norm_without_running_stats(module: nn.Module) -> None:
    """Strip the running statistics from ``module`` in place.

    Nothing happens unless ``module`` is an instance of
    ``nn.modules.batchnorm._BatchNorm`` whose ``track_running_stats`` flag is
    truthy. For such a module, ``running_mean``, ``running_var`` and
    ``num_batches_tracked`` are set to ``None`` and ``track_running_stats``
    is set to ``False``.

    Args:
        module: Any ``nn.Module``; modules that are not batch-norm layers, or
            that already have ``track_running_stats`` falsy, are left untouched.
    """
    if not isinstance(module, nn.modules.batchnorm._BatchNorm):
        return
    if not module.track_running_stats:
        return

    module.running_mean = None
    module.running_var = None
    module.num_batches_tracked = None
    # Cleared last, mirroring the original statement order.
    module.track_running_stats = False
def replace_all_batch_norm_modules_(root: nn.Module) -> nn.Module:
    """
    In place updates :attr:`root` by setting the ``running_mean``, ``running_var`` and
    ``num_batches_tracked`` to be None and setting track_running_stats to be False for
    any nn.BatchNorm module in :attr:`root` (including :attr:`root` itself).

    Args:
        root: The module whose batch-norm layers should stop tracking running statistics.

    Returns:
        The same :attr:`root` object that was passed in, after the in-place update.
    """
    # ``root.modules()`` yields ``root`` itself before its descendants, so the
    # root needs no separate "base case" call.
    for submodule in root.modules():
        batch_norm_without_running_stats(submodule)
    return root