File size: 4,175 Bytes
860b549 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import pytest
import torch
import torch.nn as nn
from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d, MultiQueryAttentionV2
import importlib
import os
# Optional hook for out-of-tree accelerator backends: when TORCH_BACKEND names an
# importable module, it is imported before any test runs (presumably so it can
# register a custom device with torch -- confirm with the backend's docs).
# An empty value is treated like "unset": importlib.import_module('') raises
# ValueError, which would otherwise abort test collection.
torch_backend = os.environ.get('TORCH_BACKEND')
if torch_backend:
    importlib.import_module(torch_backend)
# Device that tests move their tensors/modules to; defaults to CPU.
torch_device = os.environ.get('TORCH_DEVICE', 'cpu')
class MLP(nn.Module):
    """Tiny 1000 -> 100 -> 10 MLP with a configurable activation between the two linears.

    Args:
        act_layer: activation spec handed to ``create_act_layer`` (e.g. a name string
            or a layer class).
        inplace: forwarded to ``create_act_layer`` as its ``inplace`` argument.
    """

    def __init__(self, act_layer="relu", inplace=True):
        super().__init__()
        self.fc1 = nn.Linear(1000, 100)
        self.act = create_act_layer(act_layer, inplace=inplace)
        self.fc2 = nn.Linear(100, 10)

    def forward(self, x):
        # linear -> activation -> linear
        return self.fc2(self.act(self.fc1(x)))
def _run_act_layer_grad(act_type, inplace=True):
    """Check that an activation gives the same loss under default, scriptable and no-jit configs.

    Builds an ``MLP`` using ``act_type`` and computes a sum-of-squares loss on a random
    batch. It then recreates only the activation layer under
    ``set_layer_config(scriptable=True)`` and under ``set_layer_config(no_jit=True)``.
    Finally it asserts that the three scalar losses agree (``torch.isclose`` default
    tolerances).

    NOTE(review): despite the name, no backward pass is run -- only forward losses
    are compared, so gradients of the different implementations are not checked.

    Args:
        act_type: activation spec passed to ``create_act_layer``.
        inplace: ``inplace`` argument forwarded whenever the activation is (re)created.
    """
    x = (torch.rand(10, 1000) * 10).to(device=torch_device)
    model = MLP(act_layer=act_type, inplace=inplace).to(device=torch_device)

    def _loss(inp, act_layer=''):
        if act_layer:
            # Swap in a freshly created activation so the currently active layer config
            # applies. Move it to the test device as well; otherwise an activation with
            # parameters would stay on CPU while the inputs live on ``torch_device``.
            model.act = create_act_layer(act_layer, inplace=inplace).to(device=torch_device)
        out = model(inp)
        return out.pow(2).sum()

    loss_default = _loss(x)

    with set_layer_config(scriptable=True):
        loss_scriptable = _loss(x, act_type)
    assert torch.isclose(loss_scriptable, loss_default)

    with set_layer_config(no_jit=True):
        loss_no_jit = _loss(x, act_type)
    assert torch.isclose(loss_no_jit, loss_scriptable)
def test_swish_grad():
    """Compare swish across layer configs on 100 fresh random batches."""
    for _attempt in range(100):
        _run_act_layer_grad('swish')
def test_mish_grad():
    """Compare mish across layer configs on 100 fresh random batches."""
    for _attempt in range(100):
        _run_act_layer_grad('mish')
def test_hard_sigmoid_grad():
    """Compare hard_sigmoid across layer configs on 100 fresh random batches.

    ``inplace=None`` makes ``create_act_layer`` build the layer without an inplace kwarg.
    """
    for _attempt in range(100):
        _run_act_layer_grad('hard_sigmoid', inplace=None)
def test_hard_swish_grad():
    """Compare hard_swish across layer configs on 100 fresh random batches."""
    for _attempt in range(100):
        _run_act_layer_grad('hard_swish')
def test_hard_mish_grad():
    """Compare hard_mish across layer configs on 100 fresh random batches."""
    for _attempt in range(100):
        _run_act_layer_grad('hard_mish')
def test_get_act_layer_empty_string():
    """An empty activation name resolves to no layer."""
    resolved = get_act_layer('')
    assert resolved is None
def test_create_act_layer_inplace_error():
    """create_act_layer must still build a layer whose __init__ rejects ``inplace``."""

    class NoInplaceAct(nn.Module):
        # __init__ takes no extra arguments, so calling it with inplace=... raises TypeError.
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return x

    # The TypeError caused by the inplace kwarg should be recovered from, not propagated.
    created = create_act_layer(NoInplaceAct, inplace=True)
    assert isinstance(created, NoInplaceAct)
def test_create_act_layer_edge_cases():
    """Cover a None activation spec and a layer that accepts ``inplace`` via **kwargs."""
    # A None activation spec yields no layer at all.
    assert create_act_layer(None) is None

    # This layer's __init__ takes **kwargs, so passing ``inplace`` raises no TypeError
    # and the kwarg should be forwarded to the layer. The TypeError-recovery path is
    # covered separately by test_create_act_layer_inplace_error.
    class CustomAct(nn.Module):
        def __init__(self, **kwargs):
            super().__init__()
            # Record what the factory passed in, so the test can check forwarding.
            self.init_kwargs = kwargs

        def forward(self, x):
            return x

    result = create_act_layer(CustomAct, inplace=True)
    assert isinstance(result, CustomAct)
    assert result.init_kwargs == {'inplace': True}
def test_get_act_fn_callable():
    """A plain callable given to get_act_fn comes back as the very same object."""

    def passthrough(t):
        return t

    resolved = get_act_fn(passthrough)
    assert resolved is passthrough
def test_get_act_fn_none():
    """Both None and the empty string mean 'no activation function'."""
    for empty_spec in (None, ''):
        assert get_act_fn(empty_spec) is None
@pytest.mark.parametrize("dim", [128])
@pytest.mark.parametrize("dim_out", [128, 256])
@pytest.mark.parametrize("use_m", [True, False])
def test_mqa_v2(dim, dim_out, use_m):
    """MultiQueryAttentionV2 maps (1, dim, 32, 48) to (1, dim_out, 32, 48), with or without ``m``."""
    mqa = MultiQueryAttentionV2(dim, dim_out)
    x = torch.randn(1, dim, 32, 48)
    # Optional ``m`` input at half of x's spatial resolution (presumably the key/value
    # source -- confirm against MultiQueryAttentionV2).
    m = torch.randn(1, dim, 16, 24) if use_m else None
    y = mqa(x, m=m)
    assert y.shape == (1, dim_out, 32, 48)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("expand_first", [True, False])
@pytest.mark.parametrize("head_first", [True, False])
@pytest.mark.parametrize("attn_mask", [True, False])
def test_attn2d(bias, expand_first, head_first, attn_mask):
    """Attention2d's fused and non-fused attention paths must agree, with and without a mask.

    The same module (same weights) is run twice on the same input. The first run uses
    its default ``fused_attn`` setting; the second forces ``fused_attn`` to False.
    """
    x = torch.randn(1, 128, 32, 48)
    attn = Attention2d(
        128, 128, num_heads=4, bias=bias, expand_first=expand_first, head_first=head_first
    )
    if attn_mask:
        # One float entry per (query, key) pair over the 32 * 48 spatial positions.
        # randint's upper bound is exclusive, so high=2 yields a mix of 0s and 1s.
        # With high=1 the mask would be all zeros, making masking a no-op that cannot
        # reveal differences between the two paths.
        mask = torch.randint(0, 2, size=(32 * 48, 32 * 48), dtype=torch.float32)
    else:
        mask = None
    o1 = attn(x, mask)
    attn.fused_attn = False
    o2 = attn(x, mask)
    assert torch.allclose(o1, o2, atol=1e-5), f"{torch.abs(o1 - o2).max()}"
|