Spaces:
Runtime error
Runtime error
File size: 4,170 Bytes
53a077e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
from geffnet import config
from geffnet.activations.activations_me import *
from geffnet.activations.activations_jit import *
from geffnet.activations.activations import *
import torch
# True when this torch build ships the native SiLU op (F.silu / nn.SiLU);
# otherwise the tables below fall back to the project's own swish implementations.
_has_silu = hasattr(torch.nn.functional, 'silu')
# Baseline activation functions: the final fallback of get_act_fn() when no
# override, memory-efficient or jit-scripted variant is selected.
# silu/swish use the torch builtin F.silu when this torch build provides it.
_ACT_FN_DEFAULT = {
    'silu': F.silu if _has_silu else swish,
    'swish': F.silu if _has_silu else swish,
    'mish': mish,
    'relu': F.relu,
    'relu6': F.relu6,
    'sigmoid': sigmoid,
    'tanh': tanh,
    'hard_sigmoid': hard_sigmoid,
    'hard_swish': hard_swish,
}
# Jit-scripted activation functions, returned by get_act_fn() when not
# exporting and JIT is not disabled. Prefers the builtin F.silu when available.
_ACT_FN_JIT = {
    'silu': F.silu if _has_silu else swish_jit,
    'swish': F.silu if _has_silu else swish_jit,
    'mish': mish_jit,
}
# Memory-efficient (custom autograd) activation functions, returned by
# get_act_fn() only when not exporting, not scripting and JIT is not disabled.
# Keys must match the names callers pass to get_act_fn() (the same names used
# by _ACT_FN_DEFAULT / _ACT_LAYER_ME); a misspelled key is silently never hit.
_ACT_FN_ME = dict(
    silu=F.silu if _has_silu else swish_me,
    swish=F.silu if _has_silu else swish_me,
    mish=mish_me,
    hard_swish=hard_swish_me,
    hard_sigmoid=hard_sigmoid_me,
)
# Baseline activation layers: the final fallback of get_act_layer().
# silu/swish use torch's nn.SiLU when this torch build provides it.
_ACT_LAYER_DEFAULT = {
    'silu': nn.SiLU if _has_silu else Swish,
    'swish': nn.SiLU if _has_silu else Swish,
    'mish': Mish,
    'relu': nn.ReLU,
    'relu6': nn.ReLU6,
    'sigmoid': Sigmoid,
    'tanh': Tanh,
    'hard_sigmoid': HardSigmoid,
    'hard_swish': HardSwish,
}
# Jit-scripted activation layers, selected by get_act_layer() when not
# exporting and JIT is not disabled. Prefers nn.SiLU when available.
_ACT_LAYER_JIT = {
    'silu': nn.SiLU if _has_silu else SwishJit,
    'swish': nn.SiLU if _has_silu else SwishJit,
    'mish': MishJit,
}
# Memory-efficient activation layers, returned by get_act_layer() only when
# not exporting, not scripting and JIT is not disabled.
_ACT_LAYER_ME = {
    'silu': nn.SiLU if _has_silu else SwishMe,
    'swish': nn.SiLU if _has_silu else SwishMe,
    'mish': MishMe,
    'hard_swish': HardSwishMe,
    'hard_sigmoid': HardSigmoidMe,
}
# User-registered overrides, checked first by get_act_fn() / get_act_layer().
_OVERRIDE_FN = {}
_OVERRIDE_LAYER = {}
def add_override_act_fn(name, fn):
    """Register ``fn`` as the override returned by ``get_act_fn(name)``.

    Re-registering an existing ``name`` replaces the previous entry.
    """
    registry = _OVERRIDE_FN  # mutated in place; no rebinding needed
    registry[name] = fn
def update_override_act_fn(overrides):
    """Merge a ``{name: fn}`` dict into the activation-function overrides.

    Raises:
        AssertionError: if ``overrides`` is not a dict (the check is an
            ``assert``, so it is skipped under ``python -O``).
    """
    assert isinstance(overrides, dict)
    registry = _OVERRIDE_FN
    registry.update(overrides)
def clear_override_act_fn():
    """Drop all activation-function overrides.

    Rebinds the module-level registry to a fresh empty dict rather than
    clearing the existing one in place.
    """
    global _OVERRIDE_FN
    _OVERRIDE_FN = {}
def add_override_act_layer(name, fn):
    """Register ``fn`` as the override returned by ``get_act_layer(name)``.

    Re-registering an existing ``name`` replaces the previous entry.
    """
    registry = _OVERRIDE_LAYER  # mutated in place; no rebinding needed
    registry[name] = fn
def update_override_act_layer(overrides):
    """Merge a ``{name: layer}`` dict into the activation-layer overrides.

    Raises:
        AssertionError: if ``overrides`` is not a dict (the check is an
            ``assert``, so it is skipped under ``python -O``).
    """
    assert isinstance(overrides, dict)
    registry = _OVERRIDE_LAYER
    registry.update(overrides)
def clear_override_act_layer():
    """Drop all activation-layer overrides.

    Rebinds the module-level registry to a fresh empty dict rather than
    clearing the existing one in place.
    """
    global _OVERRIDE_LAYER
    _OVERRIDE_LAYER = {}
def get_act_fn(name='relu'):
    """Activation function factory: return the function registered under ``name``.

    Fetching activations by name lets export- or TorchScript-friendly
    implementations be chosen dynamically from the current ``config`` flags.

    Resolution order (first match wins):
      1. an override registered in ``_OVERRIDE_FN``;
      2. ``_ACT_FN_ME`` (memory-optimized, custom autograd) unless exporting,
         scripting, or JIT is disabled;
      3. ``swish`` for 'silu' / 'swish' while exporting;
      4. ``_ACT_FN_JIT`` (jit scripted) unless exporting or JIT is disabled;
      5. ``_ACT_FN_DEFAULT``.

    Raises:
        KeyError: if ``name`` is not resolved by steps 1-4 and is missing from
            ``_ACT_FN_DEFAULT``.
    """
    if name in _OVERRIDE_FN:
        return _OVERRIDE_FN[name]

    exporting = config.is_exportable()
    scripting = config.is_scriptable()
    jit_disabled = config.is_no_jit()

    # Prefer the memory-optimized version when the model is not being
    # exported or scripted and JIT is enabled.
    if name in _ACT_FN_ME and not (exporting or scripting or jit_disabled):
        return _ACT_FN_ME[name]

    # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
    if exporting and name in ('silu', 'swish'):
        return swish

    # NOTE: export tracing should work with jit scripted components, but I keep running into issues
    if name in _ACT_FN_JIT and not (exporting or jit_disabled):
        return _ACT_FN_JIT[name]

    return _ACT_FN_DEFAULT[name]
def get_act_layer(name='relu'):
    """ Activation Layer Factory
    Fetching activation layers by name with this function allows export or torch script friendly
    layers to be returned dynamically based on current config.

    Resolution order (first match wins):
      1. an override registered in ``_OVERRIDE_LAYER``;
      2. ``_ACT_LAYER_ME`` (memory-optimized) unless exporting, scripting, or JIT is disabled;
      3. ``Swish`` for 'silu' / 'swish' while exporting;
      4. ``_ACT_LAYER_JIT`` (jit scripted) unless exporting or JIT is disabled;
      5. ``_ACT_LAYER_DEFAULT`` (e.g. ``nn.ReLU`` for 'relu').

    Args:
        name: activation name, e.g. 'relu', 'swish', 'hard_sigmoid'.

    Returns:
        The layer entry registered for ``name`` in the first table that matches.

    Raises:
        KeyError: if ``name`` is not resolved by steps 1-4 and is missing from
            ``_ACT_LAYER_DEFAULT``.
    """
    if name in _OVERRIDE_LAYER:
        return _OVERRIDE_LAYER[name]
    use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
    if use_me and name in _ACT_LAYER_ME:
        return _ACT_LAYER_ME[name]
    if config.is_exportable() and name in ('silu', 'swish'):
        # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
        return Swish
    use_jit = not (config.is_exportable() or config.is_no_jit())
    # NOTE: export tracing should work with jit scripted components, but I keep running into issues
    # Check membership in the *layer* table we index, not the function table
    # (_ACT_FN_JIT), so the two JIT tables can diverge without a KeyError.
    if use_jit and name in _ACT_LAYER_JIT:  # jit scripted models should be okay for export/scripting
        return _ACT_LAYER_JIT[name]
    return _ACT_LAYER_DEFAULT[name]
|