from torch import nn FC_CLASS_REGISTRY = {'torch': nn.Linear} try: import transformer_engine.pytorch as te FC_CLASS_REGISTRY['te'] = te.Linear except: pass