27ccd05
1
2
3
4
5
6
7
from torch import nn

# Registry mapping a backend name to the fully-connected (linear) layer class
# for that backend.
#   'torch' -> torch.nn.Linear (always registered)
#   'te'    -> transformer_engine.pytorch.Linear (registered only if the
#              optional transformer_engine package imports successfully)
FC_CLASS_REGISTRY = {
    'torch': nn.Linear,
}

try:
    import transformer_engine.pytorch as te
    FC_CLASS_REGISTRY['te'] = te.Linear
except Exception:
    # transformer_engine is optional, so any failure to import it or to look
    # up its Linear class just leaves 'te' unregistered. Importing a package
    # with compiled extensions can raise exceptions other than ImportError,
    # so the handler stays broad.
    # `except Exception` (unlike the former bare `except:`) lets
    # KeyboardInterrupt and SystemExit propagate instead of silently
    # swallowing them.
    pass