File size: 170 Bytes
c6f2274
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
from torch import nn

# Registry mapping a short backend key to the fully-connected (linear) layer
# class for that backend.
#   "torch" -> torch.nn.Linear                    (always registered)
#   "te"    -> transformer_engine.pytorch.Linear  (only if the import succeeds)
FC_CLASS_REGISTRY = {"torch": nn.Linear}
try:
    import transformer_engine.pytorch as te

    FC_CLASS_REGISTRY["te"] = te.Linear
except Exception:
    # Transformer Engine is an optional dependency: if it cannot be imported,
    # leave "te" out of the registry. ``Exception`` is caught rather than only
    # ``ImportError`` because the import may presumably also fail for
    # environment reasons (e.g. missing native/CUDA libraries); TODO confirm.
    # Catching ``Exception`` (not a bare ``except:``) lets KeyboardInterrupt
    # and SystemExit propagate.
    pass