Spaces:
Running
on
Zero
Running
on
Zero
import torch

# Resolve an optional fused LayerNorm implementation once, at import time.
# apex is tried first, then xformers' Triton version; when neither can be
# imported the name is bound to None so callers can test for availability.
try:
    from apex.normalization import FusedLayerNorm
except ImportError:
    try:
        from xformers.triton import FusedLayerNorm
    except ImportError:
        FusedLayerNorm = None
def replace_all_layernorms(model):
    """Swap every ``torch.nn.LayerNorm`` under ``model`` for ``FusedLayerNorm``.

    Walks ``model.named_children()`` recursively. Each ``torch.nn.LayerNorm``
    child is replaced by a newly constructed ``FusedLayerNorm`` configured
    with the original's ``normalized_shape``, ``eps`` and
    ``elementwise_affine``; every other child is descended into.

    The replacement is constructed fresh: this function does not copy the
    original layer's parameters into it.

    Args:
        model: Module whose children are scanned. It is modified in place.

    Returns:
        The same ``model`` object. When ``FusedLayerNorm`` is ``None`` a
        warning is printed and ``model`` is returned unchanged.
    """
    if FusedLayerNorm is None:
        # Adjacent literals instead of a backslash continuation inside the
        # string, so source indentation can never leak into the message.
        print("WARNING: apex.normalization & xformers.triton.FusedLayerNorm is not found, "
              "skip using FusedLayerNorm")
        return model

    # Function-scope import keeps this block self-contained.
    import inspect

    # The two candidate implementations do not necessarily share a
    # constructor layout. NOTE(review): xformers' Triton FusedLayerNorm is
    # believed to be (normalized_shape, affine, eps), in which case the
    # positional apex-style call would swap eps and the affine flag --
    # confirm against the installed version. Detect the layout at runtime
    # rather than assuming one.
    try:
        ctor_params = inspect.signature(FusedLayerNorm.__init__).parameters
    except (TypeError, ValueError):
        ctor_params = {}
    uses_affine_keyword = (
        "affine" in ctor_params
        and "eps" in ctor_params
        and "elementwise_affine" not in ctor_params
    )

    # Snapshot the children so replacing entries cannot disturb iteration.
    for name, module in list(model.named_children()):
        if isinstance(module, torch.nn.LayerNorm):
            if uses_affine_keyword:
                fused = FusedLayerNorm(
                    module.normalized_shape,
                    affine=module.elementwise_affine,
                    eps=module.eps,
                )
            else:
                # Original positional order: (normalized_shape, eps, affine flag).
                fused = FusedLayerNorm(
                    module.normalized_shape, module.eps, module.elementwise_affine)
            setattr(model, name, fused)
        else:
            replace_all_layernorms(module)
    return model
def replace_all_groupnorms(model):
    """Swap every ``torch.nn.GroupNorm`` under ``model`` for apex's ``GroupNorm``.

    The apex class is imported on each call. If that import fails, a warning
    is printed and ``model`` is returned unchanged. Otherwise the direct
    children of ``model`` are scanned: a ``torch.nn.GroupNorm`` child is
    replaced by a new apex ``GroupNorm`` built from its ``num_groups``,
    ``num_channels``, ``eps`` and ``affine``; any other child is processed
    recursively.

    Args:
        model: Module whose children are scanned. It is modified in place.

    Returns:
        The same ``model`` object.
    """
    try:
        from apex.contrib.group_norm import GroupNorm as ApexGroupNorm
    except ImportError:
        print("WARNING: apex.contrib.group_norm is not found, skip using apex groupnorm")
        return model

    for child_name, child in model.named_children():
        if not isinstance(child, torch.nn.GroupNorm):
            # Not a GroupNorm itself: look for GroupNorms nested inside it.
            replace_all_groupnorms(child)
            continue
        fused = ApexGroupNorm(
            child.num_groups,
            child.num_channels,
            eps=child.eps,
            affine=child.affine,
        )
        setattr(model, child_name, fused)
    return model